もう迷わない!PyTorch fx.Graph.call_module()のエラーと解決策

2025-05-31

torch.fx.Graph.call_module() とは

torch.fx では、PyTorch モデル (nn.Module) の forward メソッドが実行される際に、その内部の演算がノードという形でグラフに記録されます。このノードにはいくつかの種類があり、その中の一つが call_module ノードです。

call_module ノードは、ある nn.Moduleforward メソッドが、その子モジュール(サブモジュール)を呼び出していることを表します。

具体的には、以下のような情報がノードとして記録されます。

  • kwargs: 呼び出しに使われるキーワード引数(辞書)。
  • args: 呼び出しに使われる位置引数(タプル)。
  • target: 呼び出されているモジュールの完全修飾名(例: self.linear など)。
  • name: ノードの一意な名前です。
  • opcode: ノードの種類を表します。call_module の場合は 'call_module' となります。

簡単な例で考えてみましょう。

import torch
import torch.nn as nn
import torch.fx

class MySubModule(nn.Module):
    def forward(self, x):
        return x * 2

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.sub_module = MySubModule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub_module(x)
        return x

# モデルをトレースしてグラフを生成
m = MyModule()
traced_graph = torch.fx.symbolic_trace(m)

# グラフのノードを出力
for node in traced_graph.graph.nodes:
    print(f"opcode: {node.op}, name: {node.name}, target: {node.target}, args: {node.args}, kwargs: {node.kwargs}")

このコードを実行すると、以下のような出力の一部が得られるでしょう(完全な出力ではありませんが、関連する部分を抜粋します)。

opcode: placeholder, name: x, target: x, args: (), kwargs: {}
opcode: call_module, name: linear, target: linear, args: (x,), kwargs: {}
opcode: call_module, name: sub_module, target: sub_module, args: (linear,), kwargs: {}
opcode: output, name: output, target: output, args: (sub_module,), kwargs: {}

ここで注目すべきは、opcode: call_module の行です。

  • name: sub_module, target: sub_module: これは MyModuleself.sub_module が呼び出されていることを示します。
  • name: linear, target: linear: これは MyModuleself.linear が呼び出されていることを示します。

torch.fx を使用してモデルを変換したり最適化したりする際に、call_module ノードは非常に重要な役割を果たします。

  1. モデルの構造の理解: グラフを走査することで、どのサブモジュールが、どのような順番で、どのような引数で呼び出されているかを把握できます。
  2. 変換のターゲット: 特定の種類のモジュール(例: nn.Conv2dnn.BatchNorm2d)の呼び出しを特定し、それらを別の実装に置き換えたり、結合したりする(例: Conv-BN融合)といった変換を行う際に、call_module ノードをターゲットとします。
  3. 部分的な最適化: モデル全体ではなく、特定のサブモジュールに対して量子化や枝刈りなどの最適化を適用する場合、call_module ノードを使ってそのサブモジュールを識別します。


torch.fx.Graph.call_module() の一般的なエラーとトラブルシューティング

ModuleNotFoundError: No module named 'torch.fx'

これは torch.fx を使用する上で最も基本的なエラーです。

  • トラブルシューティング
    • PyTorch のバージョンを確認します。torch.fx は PyTorch 1.8.0 以降で導入されました。それ以前のバージョンを使用している場合は、PyTorch をアップグレードする必要があります。
      pip install torch torchvision torchaudio --upgrade
      
    • import torch.fx または from torch.fx import symbolic_trace のように正しくインポートされているか確認します。
  • エラーの原因
    torch.fx モジュールが見つからない。

torch.fx.symbolic_trace が意図したとおりに動作しない(Graph に call_module ノードが欠落している、または不正確)

これは、モデルの forward メソッドにトレースできない操作が含まれている場合に起こりやすいです。

  • トラブルシューティング
    • トレース可能なコードに変換
      • 動的な制御フローを避けるか、テンソルの形状に依存しないように修正します。可能であれば、torch.where のようなテンソル操作に置き換えることを検討します。
      • リストや辞書を介したテンソルの受け渡しを避け、テンソル自体を直接操作するようにします。
      • サポートされていない操作がある場合、それらをカスタムの nn.Module にラップし、Traceris_leaf_module メソッドをオーバーライドして、そのカスタムモジュールを「葉」として扱うように指定することで、その内部のトレースをスキップできます。
        import torch
        import torch.nn as nn
        from torch.fx import Tracer, symbolic_trace, GraphModule
        
        class MyUnfriendlyOp(nn.Module):
            def forward(self, x):
                # トレースしにくい操作 (例: 動的なリスト生成)
                return torch.stack([x, x * 2])
        
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)
                self.unfriendly = MyUnfriendlyOp()
        
            def forward(self, x):
                x = self.linear(x)
                x = self.unfriendly(x)
                return x
        
        # MyUnfriendlyOp を葉モジュールとして扱うカスタム Tracer
        class CustomTracer(Tracer):
            def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
                if isinstance(m, MyUnfriendlyOp):
                    return True  # このモジュールは内部をトレースしない
                return super().is_leaf_module(m, module_qualified_name)
        
        m = MyModule()
        traced_graph_module = symbolic_trace(m, tracer_class=CustomTracer)
        
        # この場合、MyUnfriendlyOp は call_module ノードとして現れる
        for node in traced_graph_module.graph.nodes:
            print(node)
        
    • print デバッグ
      symbolic_trace の前後でモデルの forward メソッドに print ステートメントを追加し、トレースがどこで中断されているか、またはどの値が期待と異なるかを確認します。
  • エラーの原因
    • Python の動的な制御フロー
      if/else、ループ (for, while) など、入力テンソルの値に依存する制御フローは、静的なグラフとして表現することができません。call_module が条件分岐の内側にある場合、そのモジュールがトレースされないことがあります。
    • Python のネイティブなデータ構造の操作
      リスト、辞書などの Python オブジェクトを直接操作する(特にテンソル以外のデータを扱う)と、トレースが中断されることがあります。
    • サポートされていないPyTorchの関数/オペレーション
      torch.fx はほとんどの torch.nn モジュールや torch の関数をサポートしていますが、一部の特殊な操作(例: torch.arange でサイズが動的に決定される場合など)はトレースできないことがあります。
    • 外部ライブラリの呼び出し
      NumPyなどのPyTorch以外のライブラリの関数を直接呼び出すと、トレースが中断されます。

AttributeError: 'GraphModule' object has no attribute 'xxx'

グラフ変換後に、オリジナルの nn.Module の属性が GraphModule に引き継がれていない場合に発生することがあります。

  • トラブルシューティング
    • GraphModule への属性の追加
      必要な属性が GraphModule に存在しない場合、手動で追加するか、変換ロジックで考慮する必要があります。
      import torch.nn as nn
      from torch.fx import symbolic_trace, GraphModule
      
      class MyModule(nn.Module):
          def __init__(self):
              super().__init__()
              self.linear = nn.Linear(10, 1)
              self.custom_value = 42 # forward で使わない
      
          def forward(self, x):
              return self.linear(x)
      
      m = MyModule()
      traced_gm = symbolic_trace(m)
      
      # traced_gm.custom_value は存在しないため AttributeError になる
      # print(traced_gm.custom_value)
      
      # 必要な場合は手動で追加
      traced_gm.custom_value = m.custom_value
      print(traced_gm.custom_value)
      
    • サブモジュールの管理
      call_module ノードを操作する際、そのノードが参照するサブモジュールが GraphModule_modules ディクショナリに適切に登録されていることを確認してください。GraphModule.add_submodule()GraphModule.delete_submodule() を適切に使用します。
  • エラーの原因
    torch.fx.symbolic_trace は、モデルの forward メソッドを通じてアクセスされる属性のみをグラフに含めます。例えば、__init__ で定義されているが forward で使われていない属性や、トレース後に手動で追加した属性は、生成された GraphModule には存在しません。特に、call_module ノードを削除したり、その target を変更したりすると、参照が壊れる可能性があります。

変換後の GraphModule が元のモデルと同じ出力を生成しない

これはデバッグが難しい場合が多いですが、call_module ノードの引数や出力の不一致が原因となることがあります。

  • トラブルシューティング
    • graph.lint() の使用
      グラフを変更するたびに、またはデバッグの際には graph.lint() を呼び出す習慣をつけましょう。これにより、多くの不正なグラフ構造の変更を早期に検出できます。
    • ノードの入出力の確認
      print(node.args)print(node.kwargs) を使って、各 call_module ノードに渡されている引数を詳細に確認します。特に、変更後のグラフでそれが正しいノードを参照しているかを確認します。
    • 中間出力の比較
      変換前と変換後のモデルで、特定の中間層の出力を比較します。これにより、どこで出力が乖離し始めたか特定しやすくなります。
    • 小さな単位で変更しテスト
      一度に大きな変更を加えるのではなく、小さな変換を適用してはテストを繰り返すことで、問題の切り分けが容易になります。
  • エラーの原因
    • ノードの引数/キーワード引数の不正確な変更
      call_module ノードの argskwargs を変更した際、その変更が元のモジュールの期待する入力形式と異なる場合。
    • 依存関係の誤り
      ノード間の依存関係(node.argsnode.kwargs が他のノードの出力を参照している場合)を正しく管理できていない。
    • グラフの健全性チェックの怠り
      グラフを変更した後、graph.lint() を呼び出して、グラフが有効な状態であることを確認していない。

TypeError: forward() missing N required positional arguments / TypeError: got an unexpected keyword argument 'xxx'

これは、call_module ノードの引数が、呼び出されるサブモジュールの forward メソッドのシグネチャと一致しない場合に発生します。

  • トラブルシューティング
    • シグネチャの確認
      呼び出されるサブモジュールの forward メソッドの正確なシグネチャ(引数の名前、順序、デフォルト値など)を確認します。
    • ノードの引数の修正
      call_module ノードの argskwargs を、サブモジュールの forward メソッドに合うように調整します。
    • Python の inspect モジュール
      inspect.signature を使用して、モジュールの forward メソッドのシグネチャをプログラム的に取得し、それに基づいてノードの引数を生成することができます。
  • エラーの原因
    • call_module ノードの argskwargs を手動で操作した際に、引数の数や名前がサブモジュールの forward メソッドと合わなくなった。
    • オリジナルのモジュールが複雑な forward シグネチャ(例: *args, **kwargs を多用)を持っており、symbolic_trace がそれを正確に再現できなかった。
  • GraphModule の code プロパティの確認
    変換後の GraphModulecode プロパティ(print(traced_gm.code))を見ると、生成された Python コードを確認できます。これにより、意図しない挙動になっている箇所を見つけやすくなります。
  • 段階的なアプローチ
    複雑なモデルの場合、一度に全体をトレース・変換しようとせず、小さなサブモジュールごとに試したり、段階的に変換を適用したりすることで、問題の原因を特定しやすくなります。
  • PyTorch フォーラムやGitHub Issuesの検索
    遭遇したエラーメッセージや状況は、他のユーザーも経験している可能性があります。フォーラムやGitHubで検索することで、解決策が見つかることがあります。
  • PyTorch と torch.fx のドキュメント参照
    公式ドキュメントは最も正確で最新の情報源です。特に torch.fx の章は、内部動作を理解するために非常に役立ちます。
  • 最小限の再現コード
    エラーが発生した場合は、問題を再現できる最小限のコードを作成するように努めます。これにより、問題を特定しやすくなります。


以下の例では、モデルのトレース、call_module ノードの識別、およびそのノードを操作する基本的な方法を示します。

例1: モデルのトレースと call_module ノードの識別

この例では、シンプルなモデルをトレースし、生成されたグラフから call_module ノードを抽出し、その情報を表示します。

import torch
import torch.nn as nn
import torch.fx

# 1. シンプルなPyTorchモデルの定義
class SubModuleA(nn.Module):
    def forward(self, x):
        print("Executing SubModuleA")
        return x + 1

class SubModuleB(nn.Module):
    def forward(self, x):
        print("Executing SubModuleB")
        return x * 2

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.sub_a = SubModuleA()
        self.sub_b = SubModuleB()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.sub_a(x)
        x = self.sub_b(x)
        x = self.linear2(x)
        return x

# 2. モデルのインスタンス化とダミー入力の準備
model = MyModel()
dummy_input = torch.randn(1, 10)

# 3. モデルをトレースしてグラフを生成
# symbolic_trace はモデルの forward メソッドのシンボリック実行を行い、グラフを構築します
traced_model = torch.fx.symbolic_trace(model)

print("--- 生成されたグラフのノード ---")
# 4. グラフのノードをイテレートし、call_module ノードを識別
for node in traced_model.graph.nodes:
    # node.op はノードの種類(opcode)を表します
    if node.op == 'call_module':
        print(f"  Call Module Node:")
        print(f"    Name: {node.name}")           # グラフ内のノードの一意な名前
        print(f"    Target: {node.target}")       # 呼び出されるサブモジュールの名前 (例: 'linear1', 'sub_a')
        print(f"    Args: {node.args}")           # このノードへの入力引数 (タプル)
        print(f"    Kwargs: {node.kwargs}")       # このノードへのキーワード引数 (辞書)
        print(f"    Module Instance: {getattr(traced_model, node.target)}") # 実際のモジュールインスタンス
    else:
        print(f"  Other Node: {node.op} - {node.name}")

print("\n--- トレースされたモデルのコード ---")
# トレースされたモデルの内部で生成されたPythonコードを表示
print(traced_model.code)

print("\n--- トレースされたモデルの実行テスト ---")
# トレースされたモデルは通常の nn.Module と同様に実行できます
output = traced_model(dummy_input)
print(f"出力形状: {output.shape}")

解説

  • traced_model.code: トレースによって内部的に生成された Python コードを表示します。これは、torch.fx がどのようにモデルを再構築したかを理解するのに非常に役立ちます。
  • node.target: この属性は、MyModel__init__ で定義されたサブモジュールの名前(例: self.linear1 なら 'linear1')に対応します。これは、トレースされた GraphModule の属性として、元のモジュールのインスタンスが保持されています(getattr(traced_model, node.target) でアクセス可能)。
  • node.op == 'call_module': ノードの種類が 'call_module' であるかどうかをチェックしています。これは、nn.Module の子モジュールが呼び出されたことを意味します。
  • symbolic_trace(model): これが torch.fx の中核となる関数で、modelforward メソッドをシンボリックに実行し、PyTorch の演算を Graph オブジェクト内の Node のコレクションに変換します。

この例では、call_module ノードを見つけて、そのノードが参照するサブモジュールを別のものに置き換える方法を示します。ここでは、SubModuleASubModuleC に置き換えます。

import torch
import torch.nn as nn
import torch.fx
from torch.fx.graph import Graph, Node
from torch.fx import symbolic_trace, GraphModule

# 元のモデル定義 (例1と同じ)
class SubModuleA(nn.Module):
    def forward(self, x):
        print("Executing SubModuleA")
        return x + 1

class SubModuleB(nn.Module):
    def forward(self, x):
        print("Executing SubModuleB")
        return x * 2

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.sub_a = SubModuleA()
        self.sub_b = SubModuleB()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.sub_a(x) # ここを置き換える
        x = self.sub_b(x)
        x = self.linear2(x)
        return x

# 新しいサブモジュール
class SubModuleC(nn.Module):
    def forward(self, x):
        print("Executing SubModuleC (REPLACED!)")
        return x - 5 # 演算を変更

# 1. モデルをトレース
model = MyModel()
dummy_input = torch.randn(1, 10)
traced_model = symbolic_trace(model)

print("--- 変換前のグラフのコード ---")
print(traced_model.code)

# 2. グラフをイテレートし、特定の call_module ノードを見つける
new_graph = Graph()
env = {} # 古いノードと新しいノードのマッピングを保持

for node in traced_model.graph.nodes:
    # 既存のノードを新しいグラフにコピー
    # args と kwargs の参照を env に従って更新する
    new_node = new_graph.node_copy(node, lambda x: env[x])
    env[node] = new_node

    # 'sub_a' というターゲットを持つ call_module ノードを探す
    if new_node.op == 'call_module' and new_node.target == 'sub_a':
        print(f"\n--- 'sub_a' ノードを置き換えます ---")
        # 新しいサブモジュールを GraphModule に追加
        # ここでは新しい名前 'replaced_sub_a' で追加
        traced_model.add_submodule('replaced_sub_a', SubModuleC())
        
        # ノードのターゲットを新しいサブモジュール名に変更
        new_node.target = 'replaced_sub_a'
        print(f"  変更後ターゲット: {new_node.target}")

# 3. 新しいグラフで GraphModule を再構築
# オリジナルの traced_model._modules を新しい GraphModule にコピー
# add_submodule で追加されたモジュールも含まれる
new_traced_model = GraphModule(traced_model, new_graph) 

print("\n--- 変換後のグラフのコード ---")
print(new_traced_model.code)

print("\n--- 変換前と変換後のモデルの実行結果比較 ---")

# 変換前のモデルを実行
print("\n[元のモデルの実行]")
original_output = model(dummy_input)
print(f"元のモデルの出力: {original_output.item()}")

# 変換後のモデルを実行
print("\n[変換後のモデルの実行]")
modified_output = new_traced_model(dummy_input)
print(f"変換後のモデルの出力: {modified_output.item()}")

# 比較のために、手動で計算してみる
# x_init = dummy_input
# x_linear1 = model.linear1(x_init)
# x_relu = model.relu(x_linear1)
# x_sub_a = model.sub_a(x_relu) # 元のパス
# x_sub_c = SubModuleC()(x_relu) # 置き換え後のパス
# x_sub_b = model.sub_b(x_sub_a or x_sub_c)
# x_linear2 = model.linear2(x_sub_b)
  1. グラフのコピーと操作
    • new_graph = Graph()env = {} を使って、新しいグラフを構築しながら元のグラフのノードをコピーしています。これは、グラフを安全に操作するための一般的なパターンです。node_copy はノードを新しいグラフにコピーし、env を使って古いノードの参照を新しいノードにマッピングします。
    • new_node.op == 'call_module' and new_node.target == 'sub_a': これにより、置き換えたい特定の call_module ノードを識別します。
    • traced_model.add_submodule('replaced_sub_a', SubModuleC()): 新しい SubModuleC のインスタンスを、トレースされた GraphModule のサブモジュールとして追加します。この操作により、GraphModule はこの新しいモジュールを管理できるようになります。重要なのは、GraphModule にはサブモジュールの実際のインスタンスがディクショナリ形式で保持されている点です。
    • new_node.target = 'replaced_sub_a': call_module ノードの target 属性を新しいモジュールの名前に変更します。これにより、このノードが実行された際に SubModuleC が呼び出されるようになります。
  2. GraphModule の再構築
    • new_traced_model = GraphModule(traced_model, new_graph): 変更した new_graph を使用して新しい GraphModule インスタンスを作成します。この際、traced_model (元の GraphModule) を最初の引数として渡すことで、既存のサブモジュール (linear1, relu など、そして add_submodule で追加した replaced_sub_a) が新しい GraphModule に引き継がれます。
  3. 実行結果の比較
    元のモデルと変換後のモデルを実行し、出力が期待通りに変化したことを確認します。SubModuleAx + 1 であったのに対し、SubModuleCx - 5 であるため、出力値が異なるはずです。


torch.compile (推奨)

call_module() との関係
torch.compile は内部で FX グラフ(したがって call_module ノードも含む)を生成・操作しますが、ユーザーが直接 call_module() ノードを操作する必要はありません。torch.compile が自動的にモデルを解析し、最適化されたグラフを構築します。

利点

  • Python フォールバック
    トレースできない部分があっても、自動的に Python 実行にフォールバックするため、エラーになりにくいです。
  • 広範なサポート
    ほとんどの PyTorch モデルやデータ依存の制御フロー(if/else など)を処理できます。
  • 簡単な使用方法
    ほとんどの場合、モデルや関数を torch.compile() でラップするだけで済みます。
  • 高いパフォーマンス向上
    PyTorch モデルの実行速度を劇的に向上させることが期待できます。

欠点

  • デバッグが難しい場合があるかもしれません。
  • モデルの内部構造を直接操作したい場合には向いていません。

使用例

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = MyModel()
compiled_model = torch.compile(model) # これだけ!

dummy_input = torch.randn(1, 10)
output_original = model(dummy_input)
output_compiled = compiled_model(dummy_input)

print(f"元のモデルの出力: {output_original}")
print(f"コンパイル済みモデルの出力: {output_compiled}")

torch.jit.script / torch.jit.trace (TorchScript)

  • torch.jit.trace: 実際の入力データを使ってモデルの実行パスを記録 (トレース) して IR を構築します。データ依存の制御フローは、トレースされたパスのみが記録されます。
  • torch.jit.script: Python のサブセット (TorchScript 言語) としてモデルコードを静的に解析し、IR を構築します。制御フロー (if/else やループ) もキャプチャできます。

call_module() との関係
TorchScript もモデルのグラフ表現を生成しますが、そのIRは FX グラフとは異なります。TorchScript の IR はより低レベルで、call_module のような高レベルな概念ではなく、プリミティブな演算に分解される傾向があります。そのため、モジュールレベルでの詳細なグラフ操作には不向きです。

利点

  • 制御フローのキャプチャ (script)
    torch.jit.script はデータ非依存の制御フローを適切に扱えます。
  • 最適化
    C++ バックエンドでの実行により、パフォーマンスの向上が期待できます。
  • デプロイメント
    モデルを Python インタープリタなしで実行できる形式に変換できます。

欠点

  • trace の制限
    torch.jit.trace はデータ依存の制御フローを正しくキャプチャできません(トレース時のパスしか記録されない)。
  • デバッグの複雑さ
    エラーメッセージが分かりにくく、デバッグが難しい場合があります。
  • Python の制限
    TorchScript は Python のサブセットであり、すべての Python 構文やデータ構造をサポートしているわけではありません。

使用例 (trace の場合)

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = MyModel()
dummy_input = torch.randn(1, 10)

# モデルをトレース
traced_model = torch.jit.trace(model, dummy_input)

print("--- トレースされたモデル (TorchScript) ---")
print(traced_model.graph) # TorchScript の IR を表示

output_traced = traced_model(dummy_input)
print(f"トレースされたモデルの出力: {output_traced}")

# モデルを保存・ロードすることも可能
# traced_model.save("my_model.pt")
# loaded_model = torch.jit.load("my_model.pt")

使用例 (script の場合)

import torch
import torch.nn as nn

@torch.jit.script # @torch.jit.script アノテーションを付ける
class MyScriptableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        # TorchScript がサポートする制御フロー
        if x.mean() > 0:
            return self.linear2(self.relu(self.linear1(x)))
        else:
            return self.linear2(self.linear1(x)) * -1

model = MyScriptableModel()
dummy_input = torch.randn(1, 10)

# スクリプト化されたモデルは直接呼び出せる
scripted_model = model # アノテーションにより、インスタンス化時点でスクリプト化される

print("--- スクリプト化されたモデル (TorchScript) ---")
print(scripted_model.graph) # TorchScript の IR を表示

output_scripted = scripted_model(dummy_input)
print(f"スクリプト化されたモデルの出力: {output_scripted}")

torch.export (実験的/高度なユースケース向け)

call_module() との関係
torch.export も内部的に TorchDynamo と FX を使用しますが、生成されるグラフはさらに低レベル(ATen オペレータレベル)に分解される傾向があります。これは call_module のような高レベルなモジュール呼び出しではなく、よりプリミティブなテンソル演算のシーケンスになります。

利点

  • 正確なメタデータの追跡
    テンソルの形状に関する条件分岐など、より細かいメタデータを扱えます。
  • 完全なグラフキャプチャ
    Untraceable なコードがあるとエラーになるため、完全なグラフが生成されていることが保証されます。
  • 移植性
    生成されたグラフは、より多くのランタイム環境や言語で利用できる可能性が高いです。

欠点

  • デプロイメントパイプラインのより深い部分に組み込むことを意図しています。
  • 完全にトレース可能なコードを必要とし、複雑な Python のセマンティクスを多く含むモデルでは、コードの書き換えが必要になる場合があります。
  • まだ実験的な機能であり、変更される可能性があります。

使用例

import torch
import torch.nn as nn
from torch.export import export, ExportedProgram

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = MyModel()
dummy_input = torch.randn(1, 10)

# モデルをエクスポート
# dynamic_shapes を使用して動的な入力形状をサポートすることも可能
exported_program: ExportedProgram = export(model, (dummy_input,))

print("--- エクスポートされたプログラム (ATen レベルのグラフ) ---")
# エクスポートされたプログラムのグラフは、さらに低レベルになる
print(exported_program.graph_module.graph)

output_exported = exported_program(dummy_input)
print(f"エクスポートされたモデルの出力: {output_exported}")

# エクスポートされたプログラムはシリアライズ可能
# torch.export.save(exported_program, "exported_model.ep")
# loaded_program = torch.export.load("exported_model.ep")

call_module() との関係
ONNX は call_module のような PyTorch 特有のモジュール呼び出しの概念を持ちません。PyTorch モデルが ONNX にエクスポートされる際、nn.Module の呼び出しは、ONNX のオペレータセットで表現される低レベルな計算グラフに変換されます。

利点

  • 多くのツールとライブラリのサポート
    ONNX は広く採用されており、多くのツールと互換性があります。
  • デプロイメントの柔軟性
    ONNX Runtime などの最適化されたバックエンドで推論を実行できます。
  • フレームワーク間の相互運用性
    異なる ML フレームワーク間でモデルを共有できます。

欠点

  • 手動でのデバッグ
    エクスポートに失敗したり、期待通りの結果が得られない場合、ONNX グラフを直接デバッグする必要があります。
  • Python の動的な機能の制限
    torch.jit.trace と同様に、データ依存の制御フローや複雑な Python ロジックは正しくエクスポートされない場合があります。

使用例

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = MyModel()
dummy_input = torch.randn(1, 10)

# ONNXへのエクスポート
onnx_path = "my_model.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"}, # バッチサイズを動的にする
        "output": {0: "batch_size"}
    }
)

print(f"モデルが {onnx_path} にエクスポートされました。")

# ONNX Runtime での実行例 (onnxruntime をインストールする必要あり)
import onnxruntime
import numpy as np

sess = onnxruntime.InferenceSession(onnx_path)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# NumPy 配列として入力を用意
dummy_input_np = dummy_input.detach().numpy()
output_onnx = sess.run([output_name], {input_name: dummy_input_np})[0]

print(f"ONNX Runtime の出力: {output_onnx}")
print(f"元のPyTorchモデルの出力: {model(dummy_input).detach().numpy()}")

torch.fx.Graph.call_module() は PyTorch モデルのグラフ変換における強力な低レベル API です。しかし、ほとんどのユーザーは直接この API を操作するのではなく、以下のような高レベルな代替手段を利用することが多いです。

  • より深いレベルでグラフを操作し、高度なコンパイラ最適化やカスタムの変換を実装したい場合
    torch.fx を直接使用しますが、これはより専門的な知識を必要とします。
  • フレームワーク間での相互運用性や汎用的な推論エンジンへのエクスポートが必要な場合
    ONNX が適しています。
  • Python 以外の環境へのデプロイや、モデルのシリアライズを目的とする場合
    torch.jit.script (制御フローがある場合) や torch.jit.trace (固定パスの場合) が適しています。
  • 簡単なパフォーマンス向上やデプロイを考慮する場合
    torch.compile が最も推奨されるアプローチです。