PyTorch FX Interpreter活用術:グラフのカスタマイズと応用例
簡単に言うと、torch.fx
は、PyTorchのnn.Module
インスタンスを変換するためのツールキットです。これには、以下の3つの主要なコンポーネントがあります。
- Symbolic Tracer (シンボリックトレーサー): PyTorchモデルの
forward
メソッドを「追跡」し、実行される操作(演算、モジュール呼び出し、Pythonの組み込み関数など)を記録します。 - Intermediate Representation (中間表現): 追跡された操作は、
torch.fx.Graph
というデータ構造に変換されます。このグラフは、操作がノードとして、その依存関係がエッジとして表現されます。 - Python Code Generation (Pythonコード生成): グラフから新しいPythonコードを生成し、
torch.fx.GraphModule
として元のnn.Module
と同様に実行できるモジュールを作成します。
torch.fx.Interpreter.run()
の役割
torch.fx.Interpreter
は、この中間表現であるGraph
を「解釈(interpret)」し、そのグラフに記述された計算を実際に実行するクラスです。そして、run()
メソッドは、その解釈プロセスを開始する役割を担います。
具体的には、以下のようなことを行います。
- 引数の処理:
run()
メソッドには、グラフのプレースホルダーノード(入力)に対応する引数を渡します。これらの引数は、グラフの実行中に適切なノードに供給されます。 - 値の伝播: 各ノードの計算結果は、次のノードの入力として渡されます。これは、通常のPyTorchモデルのフォワードパスがどのように実行されるかと似ています。
- グラフの実行:
Interpreter
は、GraphModule
が持つグラフ内のノードを順番にたどります。各ノードは、PyTorchの演算(torch.add
など)、モジュール呼び出し(self.linear(x)
など)、またはPythonの組み込み関数(len()
など)に対応しています。
なぜInterpreter.run()
を使うのか?
直接GraphModule
を呼び出すだけでもグラフを実行できるのに、なぜInterpreter
を使うのでしょうか?
主な理由は以下の通りです。
-
カスタマイズと拡張性:
Interpreter
クラスを継承することで、グラフの実行方法を細かく制御したり、カスタムの動作を追加したりできます。例えば、以下のようなことができます。- プロファイリング: 各ノードの実行時間やメモリ使用量を測定し、モデルのパフォーマンスを分析する。
- デバッグ: グラフの特定の部分で中間値を出力したり、条件付きで実行を中断したりする。
- 部分的な実行: グラフの一部だけを実行したり、特定のノードの結果を事前に設定したりする。
- カスタムロジックの注入: 特定の種類の操作が実行されるときに、独自のロジックを挟み込む。
-
ノードごとの処理:
Interpreter
は、グラフ内の個々のノードを処理するためのメソッド(例:run_node
、call_function
、call_module
、call_method
)を提供します。これらのメソッドをオーバーライドすることで、ノードレベルでの挙動を変更できます。
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, Interpreter
# 簡単なモデルの定義
class MyModule(nn.Module):
def forward(self, x):
return x + 1 * 2
# モデルのシンボリックトレース
traced_module = symbolic_trace(MyModule())
# Interpreterインスタンスの作成
interpreter = Interpreter(traced_module)
# run()メソッドでグラフを実行
# グラフの入力に10を渡す
result = interpreter.run(10)
print(f"Graph execution result: {result}") # 出力: Graph execution result: 22
一般的なエラーとトラブルシューティング
TypeError: 'NoneType' object is not callable や AttributeError: 'NoneType' object has no attribute 'method_name'
原因
これは、トレースされたグラフ内で、期待される入力や中間結果がNone
になっているにもかかわらず、そのNone
に対して操作(呼び出しや属性アクセス)を行おうとしたときに発生します。
考えられる原因
- バグのあるカスタム
Interpreter
:Interpreter
を継承してカスタムロジックを追加している場合、そのロジックにバグがあり、None
を誤って伝播させてしまうことがあります。 - 初期値の問題
グラフの入力(placeholder
ノード)に対応するrun()
への引数が正しくないか、期待される型ではない場合。 - トレースの不完全性
symbolic_trace
がモデルの特定のパスを完全に追跡できなかった場合、グラフ内のノードの出力がNone
になることがあります。これは、Pythonの制御フロー(if/else
、for
ループなど)が複雑すぎたり、トレ賛成できない外部関数を呼び出している場合に起こりやすいです。
トラブルシューティング
- デバッグ用ログ
カスタムInterpreter
を使用している場合は、run_node
やcall_function
などのメソッドにデバッグ用のログを追加して、各ノードの入力と出力を追跡します。 - run()への引数の確認
interpreter.run(*args, **kwargs)
に渡す引数が、期待されるテンソルの形、型、値であることを確認します。 - トレースの確認
print(traced_module.graph)
を実行して、生成されたグラフを確認します。特にplaceholder
ノードや、None
になる可能性のあるノードの周辺を注意深く見ます。- モデルがトレース可能であるか確認します。特に、データに依存する制御フローや、PyTorchのテンソル操作に変換できない外部ライブラリの呼び出しは、トレースの妨げになります。
RuntimeError: The following operation failed in the TorchScript interpreter. ... (バックエンドエラー)
原因
これは、Interpreter
がグラフ内のPyTorch操作を実行しようとしたときに、その操作自体が何らかの理由で失敗した場合に発生します。これはFX特有のエラーというよりは、通常のPyTorchの実行時エラーに近いことが多いです。
考えられる原因
- 数値的な不安定性
非常に大きな値や小さな値、NaN
/inf
などによる計算エラー。 - メモリ不足
GPUメモリが足りない。 - デバイスの不一致
GPUとCPUのテンソルが混在している。 - 入力の不一致
テンソルの形状が演算に合致しない(例: 行列積の次元不一致)。
トラブルシューティング
- モデルの通常実行との比較
トレースされたモデルではなく、元のnn.Module
を同じ入力で実行してみて、同様のエラーが発生するか確認します。もし発生するなら、それはFXの問題ではなく、モデル自体の問題です。 - 入力テンソルのチェック
問題の操作に渡されるテンソルの形状、型、dtype
、デバイスをデバッグで確認します。 - スタックトレースの確認
エラーメッセージに続くスタックトレースを注意深く読み、どのPyTorch操作で問題が発生したかを確認します。
KeyError: 'node_name'
原因
これは、Interpreter
がグラフ内のノードの出力を参照しようとしたときに、そのノードの名前が見つからない場合に発生します。これは通常、カスタムInterpreter
で、node.name
に基づくキャッシュやルックアップに問題がある場合に発生します。
考えられる原因
- グラフの変更
実行時にプログラム的にグラフを変更していて、その変更がInterpreter
の期待と異なる場合。 - カスタムInterpreterのバグ
Interpreter
をオーバーライドしていて、self.environment
(ノードの出力を格納する辞書)へのアクセスや更新が正しくない。
トラブルシューティング
- print(traced_module.graph)
グラフ内のノード名が期待通りであるか確認します。 - カスタムInterpreterのコードレビュー
run_node
などのオーバーライドされたメソッドで、ノードの出力をself.environment[node.name] = result
のように正しく格納しているか確認します。
AssertionError (カスタムロジック関連)
原因
カスタムのInterpreter
を記述している場合、開発者が追加したassert
文が満たされなかったときに発生します。
トラブルシューティング
- 入力データの確認
アサーションの条件を満たすために、run()
に渡す入力データが適切であるか確認します。 - アサーションの条件確認
どのassert
文が失敗したかを確認し、その条件がなぜ満たされなかったのかをデバッグします。関連する変数の値を確認します。
ModuleNotFoundError や ImportError
原因
これはFX自体というよりは、Interpreter
が実行しようとしているモデルや、Interpreter
自体が依存しているモジュールが見つからない場合に発生します。
考えられる原因
- パスの問題
Pythonのパスが正しく設定されていない。 - 環境設定の誤り
必要なライブラリがインストールされていない。
PYTHONPATH
などの環境変数をチェックします。pip list
やconda list
で必要なライブラリがインストールされているか確認します。
- FXトレースの制限を理解する
FXはPythonの動的な性質のすべてをトレースできるわけではありません。特に、テンソルに依存する制御フロー、直接的なPythonのリスト操作、isinstance()
などの型チェック、カスタムクラスのインスタンス化などはトレースが難しい場合があります。 - print(traced_module.graph)
生成されたグラフを出力して、FXがモデルをどのように解釈したかを確認します。これにより、トレースが期待通りに行われたかどうかがわかります。特に、call_function
、call_module
、call_method
ノードが正しく捕捉されているかを確認します。 - FXの公式ドキュメントと例
PyTorchの公式ドキュメントやFXのチュートリアルに目を通し、同様の問題が報告されていないか確認します。 - 最小限の再現コード
問題を再現する最小限のコードスニペットを作成します。これにより、問題の原因を特定しやすくなります。 - ステップ実行とデバッグ
pdb
やIDEのデバッガを使って、Interpreter
のrun()
メソッド内をステップ実行し、各ノードの入力と出力、self.environment
の状態を確認することが非常に有効です。
例1: 基本的な Interpreter.run()
の使用
この例では、ごくシンプルなPyTorchモデルをトレースし、Interpreter
を使ってそのグラフを実行する方法を示します。
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, Interpreter
# 1. シンプルなPyTorchモデルを定義
class MySimpleModule(nn.Module):
def forward(self, x):
# x + (x * 2) のような計算
a = x * 2.0
b = x + a
return b
# 2. モデルをシンボリックトレース
# ダミー入力を使ってグラフを構築
dummy_input = torch.tensor(3.0)
traced_module = symbolic_trace(MySimpleModule())
print("--- 生成されたグラフ ---")
traced_module.graph.print_tabular()
print("\n")
# 3. Interpreterインスタンスを作成
# InterpreterはGraphModuleを受け取ります
interpreter = Interpreter(traced_module)
# 4. Interpreter.run() を使ってグラフを実行
# run()には、グラフのプレースホルダーノードに対応する引数を渡します
input_value = torch.tensor(5.0)
output_from_interpreter = interpreter.run(input_value)
# 5. 元のモデルの実行と比較
output_from_original_module = MySimpleModule()(input_value)
print(f"入力値: {input_value}")
print(f"Interpreterによる出力: {output_from_interpreter}")
print(f"元のモジュールによる出力: {output_from_original_module}")
# 両者の出力が一致することを確認
assert torch.allclose(output_from_interpreter, output_from_original_module)
print("\nInterpreterの出力と元のモジュールの出力は一致します。")
解説
MySimpleModule
という非常に基本的なnn.Module
を定義します。torch.fx.symbolic_trace()
を使って、このモデルのforward
メソッドの計算グラフを抽出します。結果はtorch.fx.GraphModule
になります。Interpreter(traced_module)
を使って、このGraphModule
を実行するためのInterpreter
インスタンスを作成します。interpreter.run(input_value)
を呼び出して、グラフに定義された計算を実行します。input_value
は、グラフの入力(placeholder
ノード)に対応します。- 元の
MySimpleModule
を直接実行した結果と比較し、Interpreter
が正しくグラフを評価していることを確認します。
例2: Interpreter
のカスタマイズ - 中間結果のフック
Interpreter
を継承し、run_node
メソッドをオーバーライドすることで、グラフの各ノードが実行される前後にカスタムロジックを挿入できます。この例では、各ノードの計算結果をフックして表示します。
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, Interpreter, GraphModule
# 1. 少し複雑なモデルを定義
class ComplexModule(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# 2. カスタムInterpreterクラスを定義
class MyCustomInterpreter(Interpreter):
def __init__(self, gm: GraphModule):
super().__init__(gm)
self.node_outputs = {} # 各ノードの出力値を保存する辞書
# run_node メソッドをオーバーライド
# このメソッドは、グラフ内の各ノードが実行されるたびに呼び出されます
def run_node(self, node: torch.fx.Node) -> any:
# ノードが実行される前の処理
print(f"--- ノード実行前: {node.name} ({node.op}) ---")
# 親クラスのrun_nodeを呼び出して、実際のノードの計算を実行
# これがノードの出力値を返します
output = super().run_node(node)
# ノードが実行された後の処理
self.node_outputs[node.name] = output # 出力値を保存
print(f"ノード実行後: {node.name}, 出力値の形状: {output.shape if isinstance(output, torch.Tensor) else 'N/A'}")
print(f"出力値の型: {type(output)}\n")
return output
# 3. モデルをシンボリックトレース
dummy_input = torch.randn(1, 10)
traced_module = symbolic_trace(ComplexModule())
print("--- 生成されたグラフ ---")
traced_module.graph.print_tabular()
print("\n")
# 4. カスタムInterpreterインスタンスを作成
custom_interpreter = MyCustomInterpreter(traced_module)
# 5. run() を使ってグラフを実行
input_tensor = torch.randn(1, 10) # バッチサイズ1、特徴量10の入力
output_from_custom_interpreter = custom_interpreter.run(input_tensor)
print(f"最終的なInterpreterによる出力: {output_from_custom_interpreter}")
# 保存された中間結果を表示 (例: 'relu_1'ノードの出力)
if 'relu_1' in custom_interpreter.node_outputs:
print(f"\n'relu_1' ノードの出力 (カスタムInterpreterから):")
print(custom_interpreter.node_outputs['relu_1'])
解説
ComplexModule
という、より一般的なニューラルネットワークモジュールを定義します。MyCustomInterpreter
クラスをInterpreter
から継承します。run_node(self, node)
メソッドをオーバーライドします。- このメソッドの内部では、
super().run_node(node)
を呼び出すことで、親クラスの本来のノード実行ロジックを呼び出しています。 - その前後で、各ノードの実行状況や出力値を表示・保存するカスタムロジックを追加しています。
- このメソッドの内部では、
custom_interpreter.run(input_tensor)
を呼び出すと、定義したrun_node
のロジックが各ノードに対して実行され、中間結果がログに出力されます。custom_interpreter.node_outputs
辞書には、各ノードの最終的な出力が保存されます。
このカスタマイズは、デバッグ、プロファイリング、または特定のノードの結果を監視する際に非常に役立ちます。
Interpreter
を継承して、特定の種類の操作(call_function
, call_module
, call_method
)の挙動をオーバーライドすることで、グラフの実行時にその操作を別のロジックに置き換えることができます。これは、例えば、特定の演算を高速なカスタム実装に置き換えたり、複雑なモジュールを簡単なモックに置き換えたりするのに便利です。
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, Interpreter, GraphModule, Node
# 1. 特徴的な演算を含むモデル
class SpecialOpModule(nn.Module):
def forward(self, x, y):
# ここでは torch.add を特殊な演算とみなす
sum_val = torch.add(x, y)
result = sum_val * 2
return result
# 2. torch.add をカスタムロジックに置き換えるInterpreter
class CustomAddInterpreter(Interpreter):
def __init__(self, gm: GraphModule):
super().__init__(gm)
# call_function をオーバーライド
# これは、torch.add(x, y) のような関数呼び出しノードを処理します
def call_function(self, target: torch.fx.subgraph_rewriter.Callable, args, kwargs) -> any:
# target が torch.add であるかチェック
if target == torch.add:
print(f"--- カスタムロジック: torch.add({args[0].item()}, {args[1].item()}) を呼び出し中 ---")
# 通常の加算ではなく、減算を実行する (デモンストレーション目的)
custom_result = args[0] - args[1]
return custom_result
else:
# その他の関数呼び出しは親クラスのデフォルト実装に任せる
return super().call_function(target, args, kwargs)
# 3. モデルをシンボリックトレース
dummy_x = torch.tensor(1.0)
dummy_y = torch.tensor(2.0)
traced_module = symbolic_trace(SpecialOpModule())
print("--- 生成されたグラフ ---")
traced_module.graph.print_tabular()
print("\n")
# 4. カスタムInterpreterインスタンスを作成
custom_add_interpreter = CustomAddInterpreter(traced_module)
# 5. run() を使ってグラフを実行
input_x = torch.tensor(10.0)
input_y = torch.tensor(3.0)
print(f"入力 x: {input_x.item()}, y: {input_y.item()}")
output_from_custom_interpreter = custom_add_interpreter.run(input_x, input_y)
print(f"カスタムInterpreterによる出力: {output_from_custom_interpreter.item()}")
# 6. 元のモデルの実行結果と比較 (add が行われる場合)
output_from_original_module = SpecialOpModule()(input_x, input_y)
print(f"元のモジュールによる出力 (add が行われる場合): {output_from_original_module.item()}")
解説
SpecialOpModule
はtorch.add
を使用するシンプルなモデルです。CustomAddInterpreter
を定義し、call_function
メソッドをオーバーライドします。- このメソッド内で、
target
がtorch.add
であるかをチェックします。 - もし
torch.add
であれば、本来の加算ではなく、デモンストレーションとして減算を実行し、その結果を返します。 else
ブロックでは、それ以外の関数呼び出しはsuper().call_function
に任せ、デフォルトの挙動を保ちます。
- このメソッド内で、
custom_add_interpreter.run()
を実行すると、グラフ内のtorch.add
ノードが検出され、定義したカスタムの減算ロジックが適用されます。
この機能は、モデルの量子化、特定の演算のグラフ最適化、またはテスト時に外部APIの呼び出しをモックするなどの高度なユースケースで非常に役立ちます。
ここでは、それらの代替方法と、それぞれのユースケースについて解説します。
直接 GraphModule を呼び出す (__call__ メソッド)
説明
torch.fx.symbolic_trace()
によって生成された GraphModule
は、通常の nn.Module
と同じように直接呼び出すことができます。実際、ほとんどの標準的なユースケースでは、これが最も一般的で推奨されるグラフ実行方法です。
GraphModule
の __call__
メソッドは、内部的にグラフを走査し、各ノードに対応する演算を実行します。これは、Interpreter
が行うのと同じ計算ロジックですが、ユーザーがフックやカスタマイズを挿入するための直接的なAPIは提供されません。
コード例
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
class MyModule(nn.Module):
def forward(self, x):
return x + 1
dummy_input = torch.tensor(10.0)
traced_module = symbolic_trace(MyModule())
print("--- 生成されたグラフ ---")
traced_module.graph.print_tabular()
print("\n")
# GraphModuleを直接呼び出す
input_val = torch.tensor(5.0)
output = traced_module(input_val) # ここが直接呼び出し
print(f"直接呼び出しによる出力: {output}")
# 元のモジュールの結果と比較
original_output = MyModule()(input_val)
assert torch.allclose(output, original_output)
print("直接呼び出しの出力は元のモジュールと一致します。")
ユースケース
- 推論
推論時にモデルを効率的に実行したい場合。 - パフォーマンス
カスタムロジックを挿入しないため、通常はInterpreter
を介するよりもオーバーヘッドが少ない可能性があります。 - 最も一般的
グラフを生成し、そのまま実行したい場合。
Interpreter.run() との比較
- カスタマイズ性
Interpreter
のような実行時のフックやノードごとのロジック変更はできません。 - シンプルさ
コードが最もシンプルで、追加のクラス定義やメソッドオーバーライドが不要です。
torch.jit.script / torch.jit.trace (TorchScript)
説明
torch.fx
とは異なるアプローチでPyTorchモデルをグラフ表現に変換・最適化する技術です。
torch.jit.trace
: ダミー入力を使ってモデルの実行パスを「記録」し、そのパス上の操作をTorchScript IR に変換します。実行時に通らなかったパスはグラフに含まれません。torch.jit.script
: Pythonコードを直接解析し、TorchScript IR (中間表現) に変換します。動的な制御フロー(if/else
、ループなど)もサポートします。
FXはPythonのコンパイラフロントエンドとして設計されており、グラフをSymbolic Python (fx.Graph
)として表現し、後続の変換や最適化に適しています。TorchScriptはデプロイに特化しており、C++ランタイムでの実行やモデルのエクスポート(ONNXなど)に適しています。
コード例 (torch.jit.script)
import torch
import torch.nn as nn
class ScriptableModule(nn.Module):
def forward(self, x):
if x.sum() > 0:
return x * 2
else:
return x / 2
# モデルをスクリプト化
scripted_module = torch.jit.script(ScriptableModule())
print("--- スクリプト化されたモデル ---")
# scripted_module は GraphModule とは異なる内部表現を持つ
print(scripted_module.graph) # TorchScript IR を表示
print("\n")
# スクリプト化されたモデルを実行
input_pos = torch.tensor([1.0, 2.0])
output_pos = scripted_module(input_pos)
print(f"入力 ([1.0, 2.0]) -> 出力: {output_pos}")
input_neg = torch.tensor([-1.0, -2.0])
output_neg = scripted_module(input_neg)
print(f"入力 ([-1.0, -2.0]) -> 出力: {output_neg}")
# 元のモジュールと比較
assert torch.allclose(output_pos, ScriptableModule()(input_pos))
assert torch.allclose(output_neg, ScriptableModule()(input_neg))
ユースケース
- 動的な制御フロー
torch.jit.script
は、if/else
やループのようなPythonの制御フローをグラフに含める必要がある場合に特に強力です。(torch.fx
は制御フローを扱うのがより複雑です) - Pythonからの独立性
Pythonインタープリタに依存しない形でモデルを実行したい場合。 - モデルのエクスポート
ONNXなどのフォーマットにモデルをエクスポートする前処理として。 - デプロイ
モデルをプロダクション環境(C++、モバイル、エッジデバイスなど)にデプロイしたい場合。
Interpreter.run() との比較
- 内部表現
FXはPythonのオペコードに近い高レベルなグラフ表現を持つ一方、TorchScriptはより低レベルで実行に最適化されたIRを持ちます。 - 目的が異なる
Interpreter
がグラフのデバッグ、分析、変換のためのツールであるのに対し、TorchScriptは主にデプロイと最適化を目的とします。
torch.compile (Dynamo, Inductor)
説明
PyTorch 2.0 で導入された torch.compile
は、モデルをより高速に実行するための新しいコンパイルスタックです。内部的にFXとDynamo(バイトコードインタープリタ)を利用してPythonコードをPyTorch IRに変換し、それをInductorなどのバックエンドで最適化されたコード(CUDA C++など)にコンパイルします。
torch.compile
は、ユーザーが明示的にFXグラフを操作する必要なく、パフォーマンスの向上を自動的に提供することを目指しています。
コード例
import torch
import torch.nn as nn
class FastModule(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = FastModule()
dummy_input = torch.randn(1, 10)
# モデルをコンパイル
compiled_model = torch.compile(model)
# コンパイルされたモデルを実行 (初回はコンパイルが発生)
output_compiled = compiled_model(dummy_input)
print(f"torch.compileによる出力: {output_compiled.shape}")
# 通常のモデルの実行
output_original = model(dummy_input)
# 結果が一致することを確認
assert torch.allclose(output_compiled, output_original)
print("torch.compileの出力は元のモデルと一致します。")
ユースケース
- 既存コードへの適用
既存のPyTorchコードに最小限の変更(torch.compile
でラップするだけ)で適用したい場合。 - 自動最適化
手動でのグラフ変換や最適化の知識なしに、自動的に高いパフォーマンスを得たい場合。 - パフォーマンスの向上
モデルのトレーニングや推論の速度を向上させたい場合。
Interpreter.run() との比較
- 目的
Interpreter
が分析・変換ツールであるのに対し、torch.compile
はパフォーマンス最適化ツールです。 - 抽象度
torch.compile
はユーザーからグラフ操作の詳細を隠蔽し、パフォーマンス向上に特化しています。Interpreter
はグラフの低レベルな分析やカスタムロジックの挿入に焦点を当てています。
説明
これは直接的な実行方法ではありませんが、Interpreter.run()
を使ってグラフを実行する前に、あるいはその代替として、GraphModule
自体を変換(書き換え、最適化)する方法です。torch.fx
は、グラフを操作するための豊富なAPIを提供します。
- カスタムパス
グラフを変換する独自のロジックを実装し、それらをパスとして適用する。 fx.subgraph_rewriter
: 特定のサブグラフパターンを別のサブグラフに置き換える。- 手動でのグラフ変更
Graph
オブジェクトに直接アクセスし、ノードの追加、削除、接続の変更などを行う。
コード例 (概念的)
# これは実行可能なコードではありませんが、概念を示します
import torch
import torch.nn as nn
from torch.fx import symbolic_trace, GraphModule, Node
from torch.fx.subgraph_rewriter import rewrite_pattern_from_src_and_repl
# 変換前のモデル
class OriginalModule(nn.Module):
def forward(self, x):
a = x * 2
b = a + a # ここを置き換えたい
return b
# トレース
traced_module = symbolic_trace(OriginalModule())
# グラフを変換する例 (手動でノードを置き換える)
# これは非常に単純化した例であり、実際のグラフ変換はより複雑です
for node in traced_module.graph.nodes:
if node.op == 'call_function' and node.target == torch.add and len(node.args) == 2 and node.args[0] == node.args[1]:
# x + x を x * 2 に置き換える
with traced_module.graph.inserting_after(node):
new_node = traced_module.graph.call_function(torch.mul, (node.args[0], torch.tensor(2.0)))
node.replace_all_uses_with(new_node)
traced_module.graph.erase_node(node)
traced_module.recompile() # グラフの変更を適用
print("--- 変換後のグラフ ---")
traced_module.graph.print_tabular()
# 変換後のGraphModuleを直接実行
output = traced_module(torch.tensor(3.0))
print(f"変換後のモジュール出力: {output}")
# `torch.fx.Interpreter.run()` は、このような変換されたGraphModuleの実行にも使用できます。
# interpreter = Interpreter(traced_module)
# output_interp = interpreter.run(torch.tensor(3.0))
ユースケース
- モデル分析
グラフをより分析しやすい形に変換する。 - 量子化
モデルの量子化パスをグラフに適用する。 - ハードウェア固有の最適化
特定のデバイスのアクセラレーターに合わせたグラフ変換。 - グラフレベルの最適化
冗長な演算の削除、演算の融合(例: Conv-BN融合)、パターンマッチングと置き換え。
- 補完的
グラフを変換した後に、その新しいグラフをInterpreter
で実行してデバッグしたり、直接呼び出してパフォーマンスをテストしたりできます。 - 目的が異なる
Interpreter
がグラフの実行時にカスタムロジックを挿入するのに対し、グラフ変形はグラフの構造自体を永続的に変更します。
方法 | 説明 | ユースケース | Interpreter.run() との主な違い |
---|---|---|---|
GraphModule の直接呼び出し | fx.symbolic_trace() で生成された GraphModule を通常の nn.Module として実行。 | 最も一般的、シンプルな実行、推論 | カスタマイズ性なし、最も直接的で軽量な実行 |
torch.jit.script / torch.jit.trace | モデルをTorchScript IRに変換し、Pythonから独立して実行可能にする。 | デプロイ(C++/モバイル)、ONNXエクスポート、動的な制御フローのサポート | 主目的はデプロイと最適化、異なるIR、FXよりも低レベル |
torch.compile | PyTorch 2.0の新機能で、モデルを自動的に高速化するためにFXなどを内部で利用。 | パフォーマンス向上(トレーニング/推論)、自動最適化、既存コードへの容易な適用 | ユーザーからグラフの詳細を隠蔽、パフォーマンス最適化が主目的 |
カスタム GraphModule 変形 | Graph オブジェクトを直接操作したり、fx.subgraph_rewriter などで構造を変更する。 | グラフレベルの最適化(演算融合、冗長削除)、ハードウェア固有の最適化、量子化、モデル分析 | グラフ構造自体を永続的に変更、実行時のフックではなく、実行前の準備 |