実践!torch.fx.Graph.nodesを使ったPyTorchモデルのグラフ変換テクニック
その中で、torch.fx.Graph.nodes
は、この生成された計算グラフを構成する個々の「ノード」のリストまたはイテラブルなコレクションを指します。
各ノードは、グラフ内の特定の操作や値を示します。具体的には、以下のような種類のノードがあります。
-
output (出力):
forward
メソッドの最終的な戻り値を表します。
-
call_method (メソッド呼び出し):
x.mean()
のような、テンソルオブジェクトのメソッド呼び出しを表します。
-
call_module (モジュール呼び出し):
self.linear(x)
のように、nn.Module
のインスタンス(サブモジュール)のforward
メソッドの呼び出しを表します。
-
call_function (関数呼び出し):
torch.relu
やtorch.add
のようなPythonの関数(torch.Tensor
に対する操作を含む)の呼び出しを表します。
-
get_attr (アトリビュート取得):
- モデルの属性(例:
self.param
のようなnn.Parameter
やself.linear
のようなサブモジュール)へのアクセスを表します。
- モデルの属性(例:
-
placeholder (プレースホルダー):
forward
メソッドへの入力(引数)を表します。例えば、forward(self, x)
のx
などがこれに該当します。
torch.fx.Graph.nodes
を使うことで、生成された計算グラフ内の各操作や値にアクセスし、それらを検査したり、順序を変更したり、新しいノードを追加したり、既存のノードを置き換えたりといった、グラフレベルでの変換を行うことができます。これは、モデルの最適化(例:畳み込みとバッチ正規化の融合)、量子化、カスタムな振る舞いの挿入などに非常に役立ちます。
Graph Break (グラフブレイク)
これが最も一般的な問題であり、torch.fx
(特にtorch.compile
と組み合わせて使用する場合)の性能最適化を阻害する主な要因です。
エラー/症状
torch.fx.symbolic_trace
が一部のコードをトレースできず、結果として生成されるグラフが不完全になる。torch.compile
を使用している場合、"Graph break"に関する警告やログが出力される。- モデルが期待通りにコンパイルされない、または最適化の恩恵を受けられない。
原因
torch.fx.symbolic_trace
は、Pythonのコードをシンボリックに実行してグラフを構築します。しかし、全てのPythonコードをトレースできるわけではありません。特に、以下のようなケースでグラフブレイクが発生しやすいです。
- テンソルの形状が動的すぎる場合
torch.compile
では、dynamic=True
を設定しない限り、テンソルの形状が固定されていることを前提とすることがあります。実行ごとに形状が変わると、再コンパイル(recompilation)が頻繁に発生し、オーバーヘッドが大きくなります。
- インプレース操作の誤用
- 不適切なインプレース操作(例:
x += y
)が、グラフの構築を妨げることがあります。
- 不適切なインプレース操作(例:
- 動的な属性アクセス (Dynamic attribute access)
getattr(self, 'linear_' + str(i))
のように、文字列操作によって動的にモジュールやパラメータにアクセスする場合、トレース時にどの属性が参照されるか特定できないことがあります。
- サポートされていないPythonの組み込み関数やモジュール
print()
以外の多くのPython組み込み関数や、特定の外部ライブラリの関数は、torch.fx
がそのセマンティクスを理解できないため、グラフブレイクの原因となります。
- データ依存の制御フロー (Data-dependent control flow)
- テンソルの値に依存する
if
文やfor
ループ。例えば、if x.sum() > 0:
のような場合、x.sum()
の値がトレース時に確定しないため、パスが決定できません。
- テンソルの値に依存する
トラブルシューティング
- 動的形状の有効化 (dynamic=True)
テンソルの形状が実行時に変化する場合、torch.compile(model, dynamic=True)
を設定することで、再コンパイルの頻度を減らせる可能性があります。 - fullgraph=Trueの利用 (torch.compile)
これを設定すると、グラフブレイクが発生した場合にエラーを発生させることができます。これにより、どこでグラフが壊れているかを強制的に特定できます。 - torch.fx.wrap()の使用
トレースできない関数やモジュールを一時的にラップして、FXがその内部を見ないようにすることができます。ただし、これは最適化の機会を失う可能性があります。 - コードの簡素化
グラフブレイクの原因となっている部分を特定し、torch.fx
がトレースしやすいようにコードを書き直します。例えば、データ依存の条件分岐を避ける、サポートされていないPython関数をPyTorchの同等の操作に置き換えるなどです。 - torch._dynamo.config.log_level = logging.INFO (または DEBUG) の設定
これにより、より詳細なデバッグログが出力され、問題の特定に役立ちます。 - torch._dynamo.explain()の使用
torch.compile
を使っている場合、これを使用すると、どこでグラフブレイクが発生し、その理由が何であるか詳細なレポートが得られます。
Nodeの不適切な変更または挿入
graph.nodes
を直接操作する場合、グラフの整合性を損なう可能性があります。
エラー/症状
- グラフを
GraphModule
に変換した後に、モデルが正しく実行されない、または間違った結果を出す。 AttributeError: 'Node' object has no attribute 'xxx'
(存在しない属性にアクセスしようとした場合)RuntimeError: The object you are trying to fetch is a dead Node. This means the Node has been removed from the Graph and cannot be used in a subsequent operation.
原因
- グラフの接続の破損
ノードのusers
(そのノードの出力を利用するノード)やall_input_nodes
(そのノードの入力となるノード)の管理が正しく行われないと、グラフの接続が壊れます。 - 不適切なargs/kwargsの変更
ノードのargs
やkwargs
を、その操作が期待する入力と異なるものに変更すると、実行時にエラーになったり、計算が不正になったりします。 - ノードの削除後の参照
graph.nodes
からノードを削除したり、node.replace_all_uses_with()
などでノードを置き換えたりした後、古いノードオブジェクトを参照し続けようとするとエラーになります。
トラブルシューティング
- デバッグ時のステップ実行
グラフを操作する際に、各ステップでprint(graph)
やfor node in graph.nodes: print(node)
を実行し、グラフの状態を確認しながら進めることで、意図しない変更を早期に発見できます。 - graph.eliminate_dead_code()の使用
不要になったノードを削除するために、グラフ操作の後にこのメソッドを呼び出すと良いでしょう。 - 変更後のGraphModuleの再構築
グラフを変更した後、必ずtorch.fx.GraphModule(model, graph)
またはgraph.recompile()
を呼び出して、変更を反映させる必要があります。 - NodeのAPIを正しく理解する
Node
のドキュメントをよく読み、replace_all_uses_with()
,append()
,prepend()
,insert_node_after()
,insert_node_before()
などのメソッドの正しい使い方を理解することが重要です。
torch.fx.symbolic_trace
が、モデルの特定の振る舞いを正確にキャプチャできないことがあります。
エラー/症状
- 特定の層や操作がグラフに現れない。
- 生成されたグラフが、元のモデルの
forward
メソッドと異なる動作をする。
原因
- PyTorchの内部的な低レベル操作
一部のPyTorchの内部的なC++レベルの操作は、fx
のシンボリックトレースの範囲外にある場合があります。 - テンソル以外のオブジェクトの操作
torch.fx
は主にテンソル操作に焦点を当てています。リスト、辞書、NumPy配列などのテンソル以外のオブジェクトに対する複雑な操作は、グラフに適切に表現されないことがあります。 - Pythonの動的な機能
非常に動的なPythonの機能(例:exec()
,eval()
,setattr()
の複雑な使用)はトレースが困難です。
- torch.fx.Interpreterの使用
fx.Interpreter
を使ってグラフをステップ実行し、各ノードの入力と出力のテンソルを検査することで、問題のある箇所を特定できます。 - Nodeのopとtargetを確認
各ノードのop
(操作の種類)とtarget
(呼び出される関数、モジュール、メソッドなど)が、意図したものであることを確認します。 - GraphModuleを実行して比較
torch.fx.symbolic_trace
で得られたGraphModule
を、元のnn.Module
と同じ入力で実行し、出力が一致するかどうかを確認します。 - print(graph)で生成されたグラフを確認
グラフのテキスト表現を確認し、期待するノードが全て存在し、正しく接続されているかを視覚的に確認します。
例1: グラフの生成とノードの表示
まず、簡単なモデルを作成し、それをtorch.fx.symbolic_trace
でトレースしてグラフを生成し、そのノードを全て表示する例です。
import torch
import torch.nn as nn
import torch.fx
# 1. シンプルなモデルの定義
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 1)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# モデルのインスタンス化
model = MyModel()
# 2. モデルをシンボリックトレースしてグラフを生成
# ダミーの入力テンソルが必要です
example_inputs = torch.randn(1, 10)
traced_model = torch.fx.symbolic_trace(model)
# 生成されたGraphオブジェクトを取得
graph = traced_model.graph
print("--- Generated Graph Nodes ---")
# 3. graph.nodesを走査して各ノードを表示
for node in graph.nodes:
print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}, Args: {node.args}, Kwargs: {node.kwargs}")
print("\n--- Graph Visualization (Optional) ---")
# グラフのテキスト表現
print(graph)
出力の解説
上記のコードを実行すると、以下のような出力(内容は環境によって多少異なる可能性があります)が得られます。
--- Generated Graph Nodes ---
Node: x, Op: placeholder, Target: x, Args: (), Kwargs: {}
Node: linear1, Op: call_module, Target: linear1, Args: (x,), Kwargs: {}
Node: relu, Op: call_module, Target: relu, Args: (linear1,), Kwargs: {}
Node: linear2, Op: call_module, Target: linear2, Args: (relu,), Kwargs: {}
Node: output, Op: output, Target: output, Args: ((linear2,),), Kwargs: {}
--- Graph Visualization (Optional) ---
# ... (graphのテキスト表現)
output
: グラフの最終出力。入力はlinear2
の出力。op: output
linear2
:self.linear2
モジュールの呼び出し。入力はrelu
の出力。op: call_module
,target: linear2
relu
:self.relu
モジュールの呼び出し。入力はlinear1
の出力。op: call_module
,target: relu
linear1
:self.linear1
モジュールの呼び出し。入力はx
。op: call_module
,target: linear1
x
:forward
メソッドの入力引数。op: placeholder
このように、graph.nodes
をイテレートすることで、グラフを構成する各操作とその詳細(操作の種類、ターゲット、引数)をプログラム的に検査できます。
例2: 特定の種類のノードの検索と情報取得
グラフの中から特定の種類のノード(例: call_module
ノード)を探し、その情報を取得する例です。
import torch
import torch.nn as nn
import torch.fx
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.activation = nn.ReLU()
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.activation(x)
x = self.pool(x)
return x
model = SimpleModel()
traced_model = torch.fx.symbolic_trace(model, example_inputs=torch.randn(1, 3, 32, 32))
graph = traced_model.graph
print("\n--- Listing all 'call_module' nodes ---")
for node in graph.nodes:
if node.op == 'call_module':
# ノードが参照しているサブモジュールをモデルから取得
# node.target は文字列 (例: 'conv', 'bn')
submodule = traced_model.get_submodule(node.target)
print(f" Module Node: {node.name}, Target: {node.target}, Type: {type(submodule).__name__}")
if isinstance(submodule, nn.Conv2d):
print(f" - Conv2d: in_channels={submodule.in_channels}, out_channels={submodule.out_channels}")
解説
traced_model.get_submodule(node.target)
を使って、その名前から実際のnn.Module
インスタンスを取得し、さらに詳細な情報を検査しています。node.target
は、呼び出されたモジュールの名前(nn.Module
で登録した名前)を表す文字列です。node.op == 'call_module'
でモジュール呼び出しを表すノードをフィルタリングしています。
例3: グラフの変換(ノードの挿入)
グラフに新しいノードを挿入する例です。ここでは、BatchNorm2d
の後ろにカスタムの統計計算ノードを挿入することを考えます。
import torch
import torch.nn as nn
import torch.fx
# カスタムの関数(FXでトレース可能なシンプルなもの)
def custom_statistics(x):
# 例として、入力テンソルの平均と標準偏差を計算し、元のテンソルと結合して返す
# FXでトレース可能にするために、テンソル操作に限定する
mean = torch.mean(x)
std = torch.std(x)
# テンソルとして返す必要がある
return x, mean, std
class ModelWithBn(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # ここに何かを挿入したい
x = self.relu(x)
return x
model = ModelWithBn()
traced_model = torch.fx.symbolic_trace(model, example_inputs=torch.randn(1, 3, 32, 32))
graph = traced_model.graph
print("--- Original Graph ---")
print(graph)
# グラフを操作する際には、ノードのリストを動的に変更するため、慎重に行う
# 一般的には、ノードのイテレーション中にリストを変更しないように注意するが、
# FXのノード挿入APIは安全に設計されている
# 挿入位置を探す: 'bn' モジュール呼び出しノードの後ろ
bn_node = None
for node in graph.nodes:
if node.op == 'call_module' and node.target == 'bn':
bn_node = node
break
if bn_node:
# 既存のノード(bn_node)の後ろに新しいノードを挿入
# custom_statistics関数を呼び出すノードを作成
with graph.inserting_after(bn_node):
# 新しいノードのオペレーションは 'call_function' (通常のPython関数)
# target は呼び出す関数オブジェクト自体
# args はその関数に渡す引数。bn_nodeの出力がこの関数の入力になる
custom_stats_node = graph.call_function(custom_statistics, args=(bn_node,))
# custom_stats_node の出力は (original_x, mean, std) のタプルであると仮定する
# そのため、元の計算フロー(reluへの入力)はタプルの最初の要素を使うように変更する
# bn_nodeの出力を利用していたノード(ここではrelu_node)を探す
# bn_node.users を使うと、bn_nodeの出力を入力として使用しているノードの集合が得られる
for user_node in list(bn_node.users): # usersはSetなのでコピーを作成
# user_nodeの引数リストを検査し、bn_nodeが入力として使われている部分を置き換える
# ここでは簡単のため、最初の引数がbn_nodeであると仮定
if user_node.args and user_node.args[0] == bn_node:
# user_nodeの引数を、custom_stats_nodeの最初の出力(元のテンソル)に置き換える
# Note: call_function がタプルを返す場合、そのタプルの要素もNodeとして扱われる
user_node.args = (custom_stats_node.all_output_nodes[0],) + user_node.args[1:]
# 新しいノードが作成されたので、GraphModuleを再構築する必要がある
# このステップは、グラフの変更を実際に反映させるために非常に重要
traced_model.recompile() # FX 0.10+ で推奨される方法
# または new_traced_model = torch.fx.GraphModule(model, graph)
print("\n--- Modified Graph (with custom_statistics) ---")
print(graph)
# 変更後のモデルをテスト
# out_x, out_mean, out_std = traced_model(torch.randn(1, 3, 32, 32))
# print(f"Output x shape: {out_x.shape}")
# print(f"Output mean: {out_mean}")
# print(f"Output std: {out_std}")
else:
print("BatchNorm node not found.")
解説
- 挿入位置の特定
まず、bn
モジュール呼び出しノード(bn_node
)を検索します。 - graph.inserting_after()
これはコンテキストマネージャーで、このブロック内で作成される新しいノードがbn_node
の直後に挿入されるようにします。 - graph.call_function()
新しいノードを作成するために使用します。custom_statistics
関数を呼び出すノードを作成します。args=(bn_node,)
で、bn_node
の出力(つまりbn
モジュールの出力)がcustom_statistics
関数の入力になるように設定します。
- 出力の再配線
custom_statistics
が元のテンソルと統計情報のタプルを返すため、bn_node
の出力を利用していた後続のノード(ここではrelu
ノード)が、custom_stats_node
の最初の出力(元のテンソル)を引数として受け取るように、そのargs
を更新します。bn_node.users
を使って、bn_node
の出力を利用しているノードを見つけます。custom_stats_node.all_output_nodes[0]
は、custom_stats_node
が返すタプルの最初の要素を表すノードです。
- traced_model.recompile()
グラフの変更を実際にモデルに適用するために、このメソッドを呼び出すことが非常に重要です。これを忘れると、モデルは変更前のグラフのまま動作します。
グラフから特定のノードを削除したり、別のノードに置き換えたりする例です。ここではBatchNorm2d
を削除する(またはパススルーにする)ことを考えます。
import torch
import torch.nn as nn
import torch.fx
class ModelWithBn(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # このノードを削除または置き換えたい
x = self.relu(x)
return x
model = ModelWithBn()
traced_model = torch.fx.symbolic_trace(model, example_inputs=torch.randn(1, 3, 32, 32))
graph = traced_model.graph
print("--- Original Graph ---")
print(graph)
# 削除/置き換え対象のノードを探す
bn_node = None
for node in graph.nodes:
if node.op == 'call_module' and node.target == 'bn':
bn_node = node
break
if bn_node:
# bn_nodeの入力(conv_nodeの出力)を取得
input_to_bn = bn_node.args[0]
# bn_nodeの出力を利用している全てのノード(relu_nodeなど)に対して、
# その入力をbn_nodeの入力(input_to_bn)に置き換える
# これにより、bn_nodeは「スキップ」される形になる
bn_node.replace_all_uses_with(input_to_bn)
# 置き換えられたノードはもう使われないので、削除する
graph.erase_node(bn_node)
# 不要になったノードをグラフから完全に削除(Optional, replace_all_uses_with 後は dead code になる)
graph.eliminate_dead_code()
# 変更を反映
traced_model.recompile()
print("\n--- Modified Graph (BN removed) ---")
print(graph)
else:
print("BatchNorm node not found.")
- bn_node.replace_all_uses_with(input_to_bn)
これが重要なステップです。bn_node
の出力(例えば、bn_node
の後にrelu_node
が続く場合、relu_node
の入力)を、bn_node
自体の入力(conv_node
の出力)に置き換えます。これにより、relu_node
はconv_node
の出力を直接受け取るようになり、bn_node
は計算グラフから切り離されます。 - graph.erase_node(bn_node)
bn_node
がもはやグラフのどの部分からも参照されていない場合、このメソッドでグラフからノードを物理的に削除します。 - graph.eliminate_dead_code()
これは、どこからも参照されていない(dead code)ノードをグラフから自動的に削除する便利な関数です。手動でerase_node
を呼び出す代わりに、最後にこれを実行するだけでも良い場合が多いです。 - traced_model.recompile()
グラフの変更を反映させるために、再度recompile()
を呼び出します。
torch.compile (推奨されるモダンなアプローチ)
torch.compile
は PyTorch 2.0 で導入された、モデル高速化のための主要なツールです。これは内部的に torch.fx
(より正確には TorchDynamo
がグラフキャプチャに利用し、AOTAutograd
と TorchInductor
がコンパイルに利用) を活用していますが、ユーザーは直接 torch.fx.Graph.nodes
を操作する必要がほとんどありません。
特徴
- 様々なバックエンド
backend
引数で、TorchInductor
(デフォルト)、aot_eager
、cudagraphs
など、様々なコンパイルバックエンドを選択できます。 - 様々な最適化
演算子融合、メモリ最適化、コード生成などを自動的に適用します。 - 自動的なグラフキャプチャ
TorchDynamo
がPythonバイトコードを解析し、グラフブレイクを最小限に抑えながら計算グラフをキャプチャします。 - 高レベルAPI
モデルや関数をtorch.compile
でラップするだけで、自動的に最適化されたコードを生成します。
torch.fx.Graph.nodes を直接操作する代わりにこれを使用する理由
- 自動最適化
ユーザーが個々の最適化パスを実装する必要なく、PyTorchが提供する最先端の最適化が自動的に適用されます。 - 堅牢性
torch.fx.symbolic_trace
で問題となりがちな「グラフブレイク」を、TorchDynamo
がより効果的に処理します。 - 生産性の向上
手動でグラフを走査し、ノードを挿入・削除する手間が省けます。
例
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 1)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = MyModel()
# モデルをコンパイルするだけで最適化が適用される
compiled_model = torch.compile(model)
# 通常通り使用できる
x = torch.randn(1, 10)
output = compiled_model(x)
print(f"Compiled model output shape: {output.shape}")
# (オプション) コンパイルの詳細なログを見る
# import logging
# torch._dynamo.config.log_level = logging.INFO
# compiled_model = torch.compile(model)
# output = compiled_model(x)
torch.jit.script / torch.jit.trace (TorchScript)
torch.jit
モジュールは、PyTorchモデルをPythonインタプリタから独立した、TorchScriptと呼ばれる中間表現に変換するためのツールです。これは、モデルのデプロイメント(C++環境での実行、モバイルデバイスへの展開など)や、Pythonのオーバーヘッドを削減したい場合に特に有用です。
特徴
- シリアライズ可能
変換されたモデルはファイルに保存し、Pythonなしでロードして実行できます。 - 最適化
TorchScriptコンパイラは、演算子融合や定数畳み込みなどの最適化を適用します。 - torch.jit.trace (トレーシング)
実際の入力データを使ってモデルを実行し、その実行パスで実行された演算を記録することでグラフを構築します。データ依存の制御フローは記録されず、実行された特定のパスのみがキャプチャされます。 - torch.jit.script (スクリプティング)
Pythonコードを直接解析し、静的なグラフ表現に変換します。データ依存の制御フロー(if
文やfor
ループ)もキャプチャできます。
torch.fx.Graph.nodes との比較
- FXはより低レベルなグラフ変換の柔軟性を提供し、研究や特殊な最適化に利用されます。TorchScriptはプロダクション環境へのデプロイメントに適しています。
- TorchScriptは、PythonからTorchScript (C++ランタイムで実行可能なIR) への変換を目的としています。
torch.fx
は主にPyTorchからPyTorchへの変換(Python-to-Python transformation)に焦点を当てており、Pythonレベルでグラフを操作し、元のPyTorchモデルに変換し直します。
例
import torch
import torch.nn as nn
class MyDynamicModel(nn.Module):
def forward(self, x):
if x.sum() > 0:
return x * 2
else:
return x / 2
model = MyDynamicModel()
# Scripting: 動的な制御フローをキャプチャできる
scripted_model = torch.jit.script(model)
print("--- Scripted Model Graph ---")
print(scripted_model.graph)
print(f"Scripted model output (positive sum): {scripted_model(torch.ones(2)).sum()}")
print(f"Scripted model output (negative sum): {scripted_model(torch.full((2,), -1.0)).sum()}")
# Tracing: 特定の実行パスのみをキャプチャする(動的な制御フローは無視されるか、警告が出る)
# この例では、x.sum() > 0 が True となるパスがトレースされる
traced_model = torch.jit.trace(model, torch.ones(2))
print("\n--- Traced Model Graph ---")
print(traced_model.graph)
print(f"Traced model output (positive sum): {traced_model(torch.ones(2)).sum()}")
# 注意: traced_modelは、トレース時の条件(x.sum() > 0 が True)に基づいているため、
# 別の入力で異なる結果になる可能性がある
print(f"Traced model output (negative sum): {traced_model(torch.full((2,), -1.0)).sum()}")
torch.ao.quantization (量子化のための高レベルFX API)
PyTorchの量子化ツールキットは、モデルのサイズを削減し、推論速度を向上させるために、torch.fx
を内部的に利用しています。ユーザーは通常、torch.fx.Graph.nodes
を直接操作することなく、高レベルのAPIを使ってモデルを量子化できます。
特徴
- 様々な量子化手法
事後学習量子化 (Post Training Quantization: PTQ) や学習時量子化 (Quantization Aware Training: QAT) など、様々な手法をサポートします。 - 融合 (Fusion)
畳み込みとバッチ正規化を融合するなど、パフォーマンスを向上させるためのモジュール融合を自動的に行います。これも内部でFXのグラフ変換を利用しています。 - 自動的な挿入
観測器(Observers)や量子化/非量子化(QuantStub/DeQuantStub)の操作をグラフに自動的に挿入します。
torch.fx.Graph.nodes を直接操作する代わりにこれを使用する理由
- 手動で量子化のためのノードを挿入・削除する手間が省けます。
- 量子化は複雑なグラフ変換を伴いますが、高レベルAPIを使用することで、その詳細を気にすることなく適用できます。
例 (概念のみ、完全な量子化コードは複雑になるため省略)
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig, quantize_fx, prepare_fx, convert_fx, QConfigMapping
class QuantizableModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
model = QuantizableModel().eval() # 推論モードに設定
# 量子化設定
qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))
# グラフモード量子化のための準備
# 内部的にtorch.fx.symbolic_traceが使われ、GraphModuleが生成される
prepared_model = prepare_fx(model, qconfig_mapping, torch.randn(1, 1, 1, 1))
# キャリブレーション(ダミーデータで実行)
# prepared_model(torch.randn(1, 1, 1, 1))
# 量子化モデルへの変換
# 内部でグラフ変換が行われ、量子化された演算が挿入される
quantized_model = convert_fx(prepared_model)
print("\n--- Quantized Model (Simplified Graph View) ---")
# 量子化されたモデルの内部構造を確認すると、Opが変更されていることがわかる
# print(quantized_model.graph) # より詳細なグラフが見られる
AOTAutograd
は torch.compile
の内部コンポーネントの一つで、順伝播と逆伝播のグラフを事前に(Ahead-Of-Time)キャプチャし、外部コンパイラとの統合を容易にするためのものです。これはtorch.fx
のノードをベースにしてグラフを表現しますが、主に微分(Autograd)の最適化に特化しています。
torch.fx.Graph.nodes との関連
- 通常、
torch.compile
を使用する場合に内部的に利用されるため、直接AOTAutograd
のAPIを操作することは稀です。 - ただし、
AOTAutograd
は自動微分の詳細を考慮したグラフを生成するため、純粋なfx.symbolic_trace
とは異なる場合があります。 AOTAutograd
もtorch.fx.GraphModule
を生成し、そのグラフに対して操作を行います。
torch.fx.Graph.nodes
を直接操作する方法は、PyTorchのグラフ変換における最も低レベルで柔軟なアプローチですが、ほとんどのユースケースでは、より高レベルで自動化された代替手段が推奨されます。
- モデルの軽量化(量子化)
torch.ao.quantization
- デプロイメント(Python非依存)
torch.jit.script
/torch.jit.trace
- 一般的なモデルの高速化
torch.compile
(最も推奨)