PyTorchの計算グラフを理解する: FXのGraphModule.print_readable()徹底解説
FXは、nn.Module
インスタンスを変換するためのツールキットで、主に以下の3つの主要なコンポーネントから構成されます。
- シンボリックトレーサー (Symbolic Tracer): Pythonのコードをシンボリックに実行し、PyTorchモデルの計算グラフ(Graph)をキャプチャします。
- 中間表現 (Intermediate Representation - IR): キャプチャされた計算グラフを、ノードのリストとして表現します。各ノードは、入力、関数の呼び出し、モジュールの呼び出し、メソッドの呼び出し、属性の取得、出力などを表します。
- Pythonコード生成 (Python Code Generation): IRから、元のモデルと同じセマンティクスを持つPythonコードを生成します。
GraphModule
は、この中間表現であるGraph
と、そこから生成されたforward
メソッドを持つtorch.nn.Module
のインスタンスです。つまり、GraphModule
は元のnn.Module
の計算ロジックを、FXが内部で操作しやすい形式(グラフ)に変換し、それを実行可能なモジュールとしてラップしたものです。
GraphModule.print_readable()
は、このGraphModule
が内部に持っている計算グラフを、Pythonのコードに近い形式で出力します。これにより、以下の情報を確認できます。
- 出力 (Output):
forward
メソッドの戻り値。 - 操作 (Operations): グラフ内で実行される各操作(関数呼び出し、モジュール呼び出し、メソッド呼び出し、属性取得など)。
- 入力 (Placeholders):
forward
メソッドに渡される引数。
例えば、以下のようなシンプルなPyTorchモデルがあったとします。
import torch
import torch.nn as nn
import torch.fx
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.rand(3, 4))
self.linear = nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param)
# モデルをシンボリックトレースしてGraphModuleを生成
m = MyModule()
graph_module = torch.fx.symbolic_trace(m)
# print_readable()でグラフを表示
graph_module.print_readable()
このprint_readable()
を実行すると、以下のような出力が得られる可能性があります(具体的な内容はモデルやPyTorchのバージョンによって異なりますが、構造は似ています)。
class GraphModule(torch.nn.Module):
def forward(self, x):
# File: <stdin>:10, code: return self.linear(x + self.param)
get_attr_param = self.param # <stdin>:9:16
add = x + get_attr_param; x = get_attr_param = None
linear = self.linear(add); add = None
return linear
この出力から、以下のことが読み取れます。
- 最終的に
linear
の結果が返されていること。 - 結果が
self.linear
モジュールに渡され、そのforward
メソッドが呼び出されていること (linear
)。 x
とself.param
が加算されていること (add
)。self.param
という属性が取得されていること (get_attr_param
)。forward
メソッドがx
という引数を取ること。GraphModule
がtorch.nn.Module
を継承していること。
このように、print_readable()
はGraphModule
が内部でどのように計算を表現しているかを、非常に分かりやすい形で可視化してくれます。これは、FXを用いてモデルの最適化、変換、デバッグを行う際に非常に役立ちます。例えば、特定の操作が期待通りにグラフ化されているか、あるいは不要な操作が挿入されていないかなどを確認するのに使われます。
一般的なエラーと問題
-
Graph Break (グラフブレイク) これは最も一般的で重要な問題です。FXトレーサーがPythonのコードを完全にトレースできず、グラフのキャプチャが中断される現象です。
print_readable()
は、グラフブレイクが発生した部分までしか表示しません。- 原因
- データ依存の制御フロー
テンソルの値に依存するif
文やfor
ループなど。例えばif x.sum() > 0:
のようなコードはグラフブレイクを引き起こします。 - サポートされていないPython組み込み関数やC関数
open()
,len()
(テンソル以外のオブジェクトに対して),id()
,print()
(ただし、print()
は通常、トレース時に無視されるか、警告が出ます)。 - テンソル以外のオブジェクトへのアクセス
テンソルではないPythonオブジェクト(リスト、辞書など)を動的に操作し、その結果が計算グラフに影響する場合。 - torch.nn.Module 以外の外部モジュールの動的な使用
FXがそのモジュールの内部実装を理解できない場合。 - torch.Size のようなテンソル以外の型の動的な操作
例えば、input.size(0)
の結果を直接計算に利用しようとすると問題になることがあります。
- データ依存の制御フロー
- 原因
-
Proxy object cannot be iterated
エラー これはグラフブレイクの一種で、FXのProxyオブジェクトがイテレート(反復処理)できないために発生します。- 原因
- Proxyオブジェクトをループ (
for p in proxy_obj:
) や、*args
,**kwargs
として関数に渡す場合。FXは、テンソル操作をシンボリックに記録するため、Pythonの通常のイテレーションのセマンティクスを直接はキャプチャできません。
- Proxyオブジェクトをループ (
- 原因
-
モジュール情報の欠落/フラット化
print_readable()
の出力を見ると、元のnn.Module
のサブモジュール名が失われ、self.linear
のような名前がlinear
のようにフラット化されていることがあります。- 原因
- FXのトレーサーは、計算ロジックをATenオペレーターレベルまで分解し、フラットなグラフとして表現します。そのため、元のモジュールの階層構造やクラス情報がそのまま保持されないことがあります。これはエラーというよりは、FXの設計上の特性です。
- 原因
-
動的シェイプの取り扱い 入力テンソルのシェイプが実行ごとに変わる(動的シェイプ)場合、FXのトレースはデフォルトで静的なシェイプを仮定します。
- 原因
- 異なるシェイプの入力で
GraphModule
を実行すると、ガード(実行時の前提条件チェック)が失敗し、再コンパイル(recompilation)が発生したり、意図しない動作になることがあります。print_readable()
自体はエラーを吐きませんが、出力されたグラフが特定の静的シェイプに特化していることが見て取れます。
- 異なるシェイプの入力で
- 原因
-
torch.nn.Module
の継承元ではないクラスのトレースtorch.nn.Module
を継承していないクラスや関数をsymbolic_trace
しようとするとエラーになることがあります。
トラブルシューティング
-
グラフブレイクへの対処
- データ依存の制御フローの排除
if
文などがテンソル値に依存している場合、可能な限り テンソル操作のみで表現できる代替手段 を探します。例えば、torch.where()
を使うことで、条件分岐をテンソル操作に変換できる場合があります。- どうしてもPythonの制御フローが必要な場合は、その部分を別の関数に切り出し、FXのトレースから除外することを検討します(ただし、これはFXの目的から逸れる可能性があります)。
- サポートされていない操作の回避
- Pythonの組み込み関数や外部ライブラリの関数でグラフブレイクが発生する場合、それらの操作をモデルの
forward
メソッドの外で行うか、FXがサポートするPyTorchのテンソル操作に置き換える ことを検討します。 torch._dynamo.logging.set_logs(graph_breaks=True)
を設定すると、グラフブレイクが発生した場所と理由が詳細にログ出力されるため、問題の特定に非常に役立ちます。
- Pythonの組み込み関数や外部ライブラリの関数でグラフブレイクが発生する場合、それらの操作をモデルの
- torch.compile の利用
PyTorch 2.0以降では、FXを内部で利用するtorch.compile
がより堅牢なコンパイラとして提供されています。torch.compile
は、グラフブレイクが発生しても、その部分だけPythonインタープリターで実行し、残りをコンパイルする「フォールバック」メカニズムを持っています。これにより、手動でFXのトレースエラーを修正する手間が省けることがあります。 - torch.fx.wrap() の使用
特定の関数がトレースされるのを防ぎ、その関数を単一のノードとしてグラフに含めたい場合に使用します。これにより、FXがその関数の内部に踏み込まず、グラフブレイクを回避できることがあります。
- データ依存の制御フローの排除
-
Proxy object cannot be iterated エラーへの対処
- 通常、FXのトレースはテンソルの「形状」や「型」に基づいてシンボリックに実行されます。
for x in tensor
のようなイテレーションは、テンソルの内容に依存するため、一般的にトレースできません。 list(proxy_object)
のようにProxyオブジェクトをPythonのリストに変換しようとすると発生します。これはFXがサポートしていません。- このような操作が必要な場合は、その部分のロジックを見直すか、
torch.fx.wrap()
を利用してその部分をトレース範囲外にするなどの工夫が必要です。
- 通常、FXのトレースはテンソルの「形状」や「型」に基づいてシンボリックに実行されます。
-
モジュール情報の欠落/フラット化
- これはエラーではなく、FXの設計思想によるものです。FXは、特定の変換(例えば、Conv-BN融合)を行うために、ATenオペレーターレベルでグラフを操作します。
- 元のモジュール構造を保持したまま変換を行いたい場合は、FXのカスタムトレーサーを実装するか、あるいはFXの目的と合致しないと判断し、別の方法を検討する必要があるかもしれません。
-
動的シェイプの取り扱い
torch.fx.symbolic_trace
を使用する際に、concrete_args
引数を使用して、特定の引数を具体的な値として扱うことで、その引数に依存する制御フローを解消できる場合があります。torch.compile
を使用している場合、デフォルトで動的シェイプをある程度自動的に処理しようとしますが、明示的にtorch.compile(model, dynamic=True)
と設定することもできます。
-
トレースできないモジュール/操作
- カスタムトレーサーの利用
torch.fx.Tracer
を継承して、is_leaf_module
メソッドをオーバーライドすることで、特定のモジュールをリーフモジュール(それ以上トレースしない単位)として扱うことができます。これにより、FXが内部を深くトレースできないカスタムモジュールや外部ライブラリのモジュールを組み込む際に役立ちます。 - 問題の再現と最小化
複雑なモデルで問題が発生した場合、問題を再現できる最小限のコード(最小再現例)を作成することが、デバッグの第一歩です。これにより、問題の原因を絞り込みやすくなります。
- カスタムトレーサーの利用
- ドキュメントの参照
PyTorch FXの公式ドキュメント(特に "Limitations oftorch.fx.symbolic_trace
API" や "Graph Breaks" のセクション)は、詳細な情報と解決策を提供しています。 - print_readable() の反復的な利用
モデルを少しずつFXでトレースし、print_readable()
でグラフの出力を見て、どこでグラフブレイクが発生しているか、あるいはどこで意図しない操作が記録されているかを特定します。 - torch._dynamo.logging.set_logs(graph_breaks=True)
torch.compile
を使っている場合に、グラフブレイクの詳細な理由と場所を特定するのに非常に役立ちます。
基本的な使用例
まず、最も基本的な使い方から始めましょう。
import torch
import torch.nn as nn
import torch.fx
# 1. シンプルなPyTorchモデルの定義
class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# 2. モデルのインスタンス化
model = SimpleModule()
# 3. symbolic_trace を使用してモデルをトレースし、GraphModule を生成
# トレース時にはダミーの入力テンソルが必要です。
dummy_input = torch.randn(1, 10)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_input})
# 4. print_readable() を呼び出してグラフを出力
print("--- Basic Example: SimpleModule ---")
graph_module.print_readable()
print("\n--- Raw Graph Output ---")
print(graph_module.graph) # print_readable() よりも低レベルなグラフ表現
出力の解説
print_readable()
の出力は、元の forward
メソッドのコードに似た形式で、各操作がどのように実行されるかを示します。
--- Basic Example: SimpleModule ---
class GraphModule(torch.nn.Module):
def forward(self, x):
# File: <stdin>:11, code: x = self.linear1(x)
linear1 = self.linear1(x) # <stdin>:9:16
# File: <stdin>:12, code: x = self.relu(x)
relu = self.relu(linear1); linear1 = None
# File: <stdin>:13, code: x = self.linear2(x)
linear2 = self.linear2(relu); relu = None
return linear2
- 各行の末尾に
xxx = None
のようなものが見られるのは、PythonのGCがオブジェクトを解放するタイミングをシミュレートし、メモリ使用量を減らすためのFXの内部最適化です。 linear1 = self.linear1(x)
:self.linear1
サブモジュールのforward
メソッドが呼び出されていることを示します。# File: <stdin>:11, code: x = self.linear1(x)
: 元のコードのどの行がこの操作に対応するかを示します。def forward(self, x):
: トレースされたforward
メソッドのシグネチャです。class GraphModule(torch.nn.Module):
: FXによって生成されたクラスがtorch.nn.Module
を継承していることを示します。
print(graph_module.graph)
は、print_readable()
よりも低レベルな、ノードのリストとしてのグラフ表現を出力します。
--- Raw Graph Output ---
graph():
%x : [#users=1] = placeholder[target=x]
%linear1 : [#users=1] = call_module[target=linear1](args = (%x,))
%relu : [#users=1] = call_module[target=relu](args = (%linear1,))
%linear2 : [#users=1] = call_module[target=linear2](args = (%relu,))
return %linear2
グラフブレイクの例とトラブルシューティング
print_readable()
が出力するグラフが途中で切れてしまう場合、それは「グラフブレイク」が発生していることを意味します。
import torch
import torch.nn as nn
import torch.fx
# グラフブレイクを引き起こすモデルの例
class GraphBreakModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
def forward(self, x):
# グラフブレイクを引き起こすデータ依存の制御フロー
# テンソルの値に依存する if 文は、シンボリックトレースできません
if x.mean() > 0.5: # ここでグラフブレイクが発生する可能性が高い
x = self.conv(x)
else:
x = -x # この部分はトレースされない可能性が高い
return x
model_break = GraphBreakModule()
dummy_input_break = torch.randn(1, 3, 32, 32)
try:
graph_module_break = torch.fx.symbolic_trace(model_break, concrete_args={'x': dummy_input_break})
print("\n--- Graph Break Example: GraphBreakModule (might be incomplete) ---")
graph_module_break.print_readable()
except torch.fx.subgraph_rewriter.GraphBreakingException as e:
print(f"\n--- Graph Break Example: GraphBreakModule (Error Caught) ---")
print(f"Graph Breaking Exception: {e}")
print("FXはデータ依存の制御フローをトレースできません。")
print("print_readable() は完全なグラフを表示できません。")
except Exception as e:
print(f"\n--- Graph Break Example: GraphBreakModule (General Error) ---")
print(f"An unexpected error occurred during tracing: {e}")
# トラブルシューティング: torch.where を使用してグラフブレイクを回避
class NoGraphBreakModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
def forward(self, x):
# テンソル操作のみで条件分岐を表現
condition = x.mean() > 0.5
x_conv = self.conv(x)
x_neg = -x
# torch.where はテンソル操作なのでトレース可能
x_out = torch.where(condition, x_conv, x_neg)
return x_out
model_no_break = NoGraphBreakModule()
graph_module_no_break = torch.fx.symbolic_trace(model_no_break, concrete_args={'x': dummy_input_break})
print("\n--- Troubleshooting Example: NoGraphBreakModule (using torch.where) ---")
graph_module_no_break.print_readable()
出力の解説とトラブルシューティング
GraphBreakModule
では、if x.mean() > 0.5:
の行でグラフブレイクが発生するため、print_readable()
は完全なグラフを表示できません。実際には torch.fx.symbolic_trace
がエラーを投げるか、部分的なグラフしか生成しません。
NoGraphBreakModule
では、torch.where
を使ってテンソル操作だけで条件分岐を表現しています。これにより、FXはモデル全体をシンボリックにトレースでき、print_readable()
は完全なグラフを表示します。
--- Troubleshooting Example: NoGraphBreakModule (using torch.where) ---
class GraphModule(torch.nn.Module):
def forward(self, x):
# File: <stdin>:53, code: condition = x.mean() > 0.5
mean = x.mean()
gt = mean > 0.5; mean = None
# File: <stdin>:54, code: x_conv = self.conv(x)
conv = self.conv(x)
# File: <stdin>:55, code: x_neg = -x
neg = -x
# File: <stdin>:57, code: x_out = torch.where(condition, x_conv, x_neg)
where = torch.where(gt, conv, neg); gt = conv = neg = None
return where
この出力から、x.mean()
, >
(gt), self.conv(x)
, -x
, torch.where
といった各操作が順序立ててトレースされていることがわかります。
PyTorch 2.0 以降では、FX を内部で利用する torch.compile
がより推奨される方法です。torch.compile
はグラフブレイクを自動的にフォールバックで処理してくれるため、手動でのFXトレースよりも堅牢です。print_readable()
は torch.compile
によって生成された GraphModule
でも使用できます。
import torch
import torch.nn as nn
import torch.fx
# torch.compile で使用するモデル
class CompiledModule(nn.Module):
def __init__(self):
super().__init__()
self.embedding = nn.Embedding(100, 32)
self.lstm = nn.LSTM(32, 64, batch_first=True)
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = self.embedding(x)
x, _ = self.lstm(x)
x = self.fc(x[:, -1, :]) # LSTMの最後のシーケンス要素を取る
return x
model_compiled = CompiledModule()
# torch.compile でモデルをコンパイル
# mode='reduce-overhead' や 'max-autotune' なども指定できます
compiled_model = torch.compile(model_compiled)
# ダミー入力 (LSTMなのでシーケンスデータ)
dummy_input_compiled = torch.randint(0, 100, (2, 5)) # batch_size=2, seq_len=5
# compiled_model の forward を一度実行して、グラフを実際にコンパイルさせる
# print_readable() を呼ぶ前に一度実行することが重要です
_ = compiled_model(dummy_input_compiled)
# torch.compile は内部的に GraphModule を生成しています。
# その GraphModule にアクセスするには、通常 `_orig_mod` を介します。
# ただし、torch.compile の内部構造はバージョンによって変わる可能性があります。
# 確実な方法は、print_readable() が直接使えるわけではないため、
# torch._dynamo.explain() などを使う方が一般的です。
# ここでは、あくまで compiled_model の内部で生成されたグラフをイメージするために、
# 仮にアクセスできるものとして例を記述します。
# 注意: 以下のコードは、compiled_model の内部構造に直接アクセスしようとするもので、
# PyTorchのバージョンによっては動作しない可能性があります。
# デバッグ目的であれば、torch._dynamo.explain() を推奨します。
if hasattr(compiled_model, '_torch_dynamo_module') and hasattr(compiled_model._torch_dynamo_module, 'nn_module'):
# torch._dynamo.DynamoModule の内部の GraphModule を取得しようとする試み
# これは PyTorch の内部実装に依存するため、将来的に変更される可能性があります
# 一般的なユースケースではないことに注意してください
internal_graph_module = compiled_model._torch_dynamo_module.nn_module
print("\n--- Compiled Model Internal Graph (if accessible) ---")
internal_graph_module.print_readable()
else:
print("\n--- Compiled Model Internal Graph ---")
print("Note: Accessing the internal GraphModule of a compiled model directly")
print(" is not officially supported and may not work across all PyTorch versions.")
print(" Use torch._dynamo.explain() for debugging compiled models.")
# compiled_model のデバッグには torch._dynamo.explain() が便利
print("\n--- Debugging compiled_model with torch._dynamo.explain() ---")
# explain() は内部的に GraphModule をトレースし、その結果を出力します
# これは print_readable() と似た情報を提供します
# ただし、実行時の詳細な情報も含まれます
torch._dynamo.explain(compiled_model, dummy_input_compiled)
torch._dynamo.explain() について
torch.compile
でコンパイルされたモデルの内部構造を理解するには、print_readable()
を直接呼び出すよりも、torch._dynamo.explain()
を使用する方が一般的で安定しています。これは、コンパイルプロセス中に生成されたグラフ、適用された最適化、発生した可能性のあるグラフブレイクなど、より詳細な情報を提供します。
torch.fx.GraphModule.print_readable()
は、PyTorch FX を使用してモデルの計算グラフをデバッグする際に非常に強力なツールです。
- torch.compile との連携
torch.compile
もFXを内部で利用していますが、コンパイルされたモデルのデバッグには、print_readable()
を直接使うよりもtorch._dynamo.explain()
の方が推奨されます。 - トラブルシューティング
グラフブレイクを避けるためには、Pythonの制御フローをtorch.where
のようなテンソル操作に置き換えることを検討します。 - グラフブレイクの特定
print_readable()
の出力が途中で終わっている場合、それはデータ依存の制御フローなど、FXがトレースできない構造が原因である可能性が高いです。 - 基本的な使い方
torch.fx.symbolic_trace()
でGraphModule
を作成した後、そのインスタンスに対してprint_readable()
を呼び出すだけです。
print(graph_module.graph) または repr(graph_module.graph)
これは最も直接的な代替手段であり、print_readable()
の基盤となるものです。GraphModule
オブジェクトが保持する生の Graph
オブジェクトを直接出力します。
- 用途
- FXグラフの構造を厳密に確認したい場合。
- カスタムのFXパスを開発する際に、各ノードのプロパティをプログラム的に操作したい場合。
- 欠点
print_readable()
ほど人間が読みやすい形式ではありません。特に複雑なグラフでは、行数が多くなり、視覚的に追うのが難しい場合があります。- 元のPythonコードとの対応関係は示されません。
- 利点
- FXグラフの内部表現(ノードのリスト)を直接見ることができます。
- 各ノードの
target
、args
、kwargs
、op
(操作タイプ:placeholder
,call_function
,call_module
,call_method
,get_attr
,output
) が明確に示されます。 print_readable()
が行っている「整形」をしないため、より低レベルでの確認が可能です。
例
import torch
import torch.nn as nn
import torch.fx
class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
def forward(self, x):
return self.linear1(x)
model = SimpleModule()
dummy_input = torch.randn(1, 10)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_input})
print("--- Using print(graph_module.graph) ---")
print(graph_module.graph)
graph_module.graph.dump_python_source()
このメソッドは、FXグラフから対応するPythonコードを生成して返します。これは、FXが最終的にコンパイルして実行可能なモジュールを生成する際に使用するコードと非常に似ています。
- 用途
- FXがどのようにPythonコードを生成しているかを確認したい場合。
- 生成されたコードに問題がないかを検証したい場合。
- 欠点
print_readable()
のように、元のコード行へのコメントは通常含まれません。- 生成されるコードは最適化の結果であり、元のコードとは必ずしも一致しない場合があります。
- 利点
print_readable()
よりも、実際に生成されるコードに近い形式で見ることができます。- FXがどのようにして元のモデルのロジックを再構築しているかを理解するのに役立ちます。
- 生成されたコードをコピーして、スタンドアロンで実行し、FXの出力が期待通りかテストすることも可能です。
例
import torch
import torch.nn as nn
import torch.fx
class SimpleModule(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
def forward(self, x):
return self.linear1(x)
model = SimpleModule()
dummy_input = torch.randn(1, 10)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_input})
print("--- Using graph_module.graph.dump_python_source() ---")
print(graph_module.graph.dump_python_source())
可視化ツール (Graphviz など)
テキストベースの出力だけでなく、グラフを画像として可視化することも可能です。これにより、特に複雑なモデルでは、データの流れや依存関係を直感的に把握できます。
- 用途
- モデルの構造を視覚的にデバッグしたい場合。
- プレゼンテーションやドキュメントにグラフを含めたい場合。
- 欠点
- Graphviz がシステムにインストールされている必要があります。
- セットアップに少し手間がかかります。
- 利点
- ノード間の接続(データの流れ)を視覚的に表現できます。
- 大規模なグラフでも、構造を一目で把握しやすくなります。
- PDF、PNG などの形式で保存できます。
- torch.fx.passes.graph_drawer.FxGraphDrawer
PyTorch FX自体に、Graphviz を使用してグラフを可視化するツールが組み込まれています。
例 (Graphviz のインストールが必要です)
import torch
import torch.nn as nn
import torch.fx
from torch.fx.passes.graph_drawer import FxGraphDrawer
class ComplexModule(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 3, padding=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(6, 16, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc = nn.Linear(16 * 8 * 8, 10) # 例として画像サイズ32x32を想定
def forward(self, x):
x = self.pool(self.relu1(self.conv1(x)))
x = self.pool(self.relu1(self.conv2(x)))
x = torch.flatten(x, 1)
x = self.fc(x)
return x
model = ComplexModule()
dummy_input = torch.randn(1, 3, 32, 32)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_input})
print("\n--- Generating Graphviz visualization ---")
# Graphviz インストール済みであることを前提
# draw_graph(graph_module.graph, 'graph.png', 'MyFXGraph')
# 上記の行は、実行すると 'graph.png' というファイルが生成されます。
# 動作確認のためには、以下のようにインスタンス化して描画メソッドを呼び出す必要があります。
try:
g = FxGraphDrawer(graph_module, "MyFXGraph")
# 'graph.png' という名前でPNGファイルを生成
file_path = "fx_graph_visualization.png"
g.get_dot_graph().write_png(file_path)
print(f"Graph visualization saved to {file_path}")
print("Graphviz がインストールされていない場合、この部分はエラーになるか動作しません。")
except Exception as e:
print(f"Error during Graphviz drawing: {e}")
print("Graphviz がインストールされていることを確認してください。")
print("例: sudo apt-get install graphviz (Ubuntu)")
print(" brew install graphviz (macOS)")
torch._dynamo.explain() (PyTorch 2.0 以降)
torch.compile
でコンパイルされたモデルのデバッグに特化していますが、内部でFXを利用しており、非常に詳細な情報を提供します。
- 用途
torch.compile
でモデルを最適化している場合に、その動作を理解し、デバッグしたい場合。- パフォーマンスのボトルネックやグラフブレイクの原因を特定したい場合。
- 欠点
torch.compile
を使用していないFXグラフのデバッグには直接使えません。
- 利点
- コンパイルされたグラフの生成過程、適用された最適化、発生したフォールバック(グラフブレイク)など、包括的な情報を提供します。
print_readable()
のようなグラフのテキスト表現に加えて、バックエンドに関する情報も含まれます。torch.compile
を利用している場合に、最も推奨されるデバッグツールです。
例
import torch
import torch.nn as nn
import torch.fx
import torch._dynamo
class CompiledExampleModule(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.tensor(1.0))
def forward(self, x):
# 意図的にグラフブレイクを発生させる可能性のある操作
if self.param.item() > 0.5: # .item() はグラフブレイクを引き起こす可能性
return x + self.param
return x * 2
model = CompiledExampleModule()
compiled_model = torch.compile(model)
dummy_input = torch.randn(5)
print("\n--- Using torch._dynamo.explain() ---")
# explain() を呼び出すと、詳細なログが出力されます
torch._dynamo.explain(compiled_model, dummy_input)
- torch.compile 使用時の包括的なデバッグ
torch._dynamo.explain()
torch.fx.GraphModule.print_readable()
は、PyTorch FX のGraphModule
の内部表現をテキスト形式で出力する際に非常に便利ですが、他にも計算グラフを理解・可視化するための様々な方法があります。目的に応じて、以下の代替手段を検討すると良いでしょう。 - 視覚的なグラフ表示
torch.fx.passes.graph_drawer.FxGraphDrawer
(Graphviz が必要) - 生成されるPythonコード
graph_module.graph.dump_python_source()
- 生のFXノードリスト
print(graph_module.graph)
torch.fx.Graph オブジェクトを直接操作する
GraphModule
は内部に Graph
オブジェクトを持っています。この Graph
オブジェクトは、ノードのリスト(graph_module.graph.nodes
)として表現されており、各ノードは操作(関数呼び出し、モジュール呼び出し、属性取得など)を表します。
print_readable()
はこの Graph
オブジェクトを整形して表示していますが、デバッグやカスタム解析のためには、直接 Graph
オブジェクトを操作することもできます。
import torch
import torch.nn as nn
import torch.fx
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
model = MyModule()
dummy_input = torch.randn(1, 3, 32, 32)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_input})
print("--- Iterating through Graph Nodes ---")
for node in graph_module.graph.nodes:
print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}, Args: {node.args}, Kwargs: {node.kwargs}")
print("\n--- Accessing specific node attributes ---")
# 例えば、最初の呼び出しノード (placeholder は除く)
call_nodes = [n for n in graph_module.graph.nodes if n.op == 'call_module']
if call_nodes:
first_call_node = call_nodes[0]
print(f"First call module target: {first_call_node.target}")
# 対応するサブモジュールも取得できる
if hasattr(graph_module, first_call_node.target):
print(f"Corresponding submodule: {getattr(graph_module, first_call_node.target)}")
利点
- カスタムのグラフ変換や分析ツールを構築する際に基盤となる。
- より詳細なレベルでグラフ構造にアクセスし、プログラムで解析・操作できる。
欠点
- 人間が直接読むには冗長で、グラフの全体像を掴みにくい。
グラフィカルな可視化ツール (Graph Visualization Tools)
テキスト出力だけでなく、視覚的にグラフを表現することで、モデルの構造やデータフローを直感的に理解しやすくなります。
-
torchview (外部ライブラリ)
torchview
は、nn.Module
の構造を視覚的に表示することに特化したライブラリです。torch.fx
と互換性があり、FXがトレースできるモデルであれば、その構造を綺麗に描画できます。# pip install torchview from torchview import draw_graph import torch import torch.nn as nn class BranchingModule(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(10, 20) self.linear2 = nn.Linear(20, 30) self.linear3 = nn.Linear(20, 5) def forward(self, x): x = self.linear1(x) y1 = self.linear2(x) y2 = self.linear3(x) return y1 + y2 # 複数のパスが合流 model = BranchingModule() dummy_input = torch.randn(1, 10) # expand_nested=True でサブモジュールを詳細に展開 model_graph = draw_graph(model, input_size=(1, 10), expand_nested=True) # グラフを保存 (requires graphviz) model_graph.visual_graph.render("branching_module_graph", format="png", cleanup=True) print("torchview diagram saved as 'branching_module_graph.png'") # Jupyter Notebook などでは、model_graph.visual_graph を直接表示できる
利点
- FX互換のグラフを、サブモジュールの階層構造を考慮して綺麗に可視化できる。
- 入力シェイプやデータ型も表示可能。
欠点
- 外部ライブラリのインストールが必要。Graphviz も必要。
-
torchviz (外部ライブラリ)
torchviz
は、PyTorchの計算グラフ(autograd graph)をGraphvizを使って可視化する外部ライブラリです。これはフォワードパスだけでなく、バックワードパスのオペレーションも可視化できる点が特徴です。import torch import torch.nn as nn from torchviz import make_dot class MyComplexModule(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 16, 3) self.pool = nn.MaxPool2d(2) self.fc = nn.Linear(16 * 14 * 14, 10) # 例としてサイズを仮定 def forward(self, x): x = self.conv(x) x = self.pool(x) x = x.view(x.size(0), -1) # Flatten x = self.fc(x) return x model = MyComplexModule() dummy_input = torch.randn(1, 3, 30, 30) # convとpoolに適した入力サイズ # 順伝播を実行して出力テンソルを得る output = model(dummy_input) # make_dot を使用してグラフを生成 # params を渡すと、モデルのパラメータもノードとして表示される graph = make_dot(output, params=dict(model.named_parameters())) # グラフをファイルに保存 (Graphvizがインストールされている必要があります) graph.render("my_complex_module_graph", format="png", cleanup=True) print("Graphviz diagram saved as 'my_complex_module_graph.png'")
- 計算の流れが視覚的にわかりやすい。
- TensorBoardは大規模モデルで重くなりがちだが、
torchviz
は特定の出力テンソルからのパスを可視化できる。 欠点: - FXのIRとは異なる(autograd graph)。
- Graphviz のインストールが必要。
-
TensorBoardの add_graph
PyTorchはTensorBoardとの連携機能を持っており、SummaryWriter.add_graph()
を使うことでモデルの計算グラフを可視化できます。これはFXグラフというよりは、PyTorchのオートグラッドグラフに近いですが、一般的なモデルの構造把握には非常に有効です。from torch.utils.tensorboard import SummaryWriter import torch import torch.nn as nn class SimpleNet(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 50) self.relu = nn.ReLU() self.fc2 = nn.Linear(50, 1) def forward(self, x): return self.fc2(self.relu(self.fc1(x))) model = SimpleNet() dummy_input = torch.randn(1, 10) writer = SummaryWriter('runs/simple_net_graph') writer.add_graph(model, dummy_input) writer.close() print("TensorBoard graph logged to 'runs/simple_net_graph'.") print("Run: tensorboard --logdir runs")
実行後、ターミナルで
tensorboard --logdir runs
を実行し、ブラウザで表示されるURLにアクセスすると、"Graphs" タブでモデルのグラフを見ることができます。
torch.compile
を使用している場合、torch._dynamo.explain()
は、コンパイルプロセス中に生成されたFXグラフ、適用された最適化、発生したグラフブレイクなど、詳細なデバッグ情報を提供します。これは print_readable()
の上位互換のようなものです。
import torch
import torch.nn as nn
import torch.fx
class ComplexBranch(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.rand(10))
self.linear = nn.Linear(10, 10)
def forward(self, x):
if x.sum() > self.param.sum(): # グラフブレイクの可能性
return self.linear(x)
else:
return x + self.param
model = ComplexBranch()
dummy_input = torch.randn(10)
# TORCH_LOGS="graph_breaks" 環境変数を設定すると、
# グラフブレイクに関する詳細なログが出力されます。
# Pythonコード内で設定するには以下のようにします:
import os
os.environ["TORCH_LOGS"] = "graph_breaks"
print("--- Using torch._dynamo.explain() for compiled model ---")
compiled_model = torch.compile(model)
try:
_ = compiled_model(dummy_input)
except Exception as e:
print(f"Execution with compiled_model might fail due to graph break: {e}")
# explain() を呼び出すことで、コンパイル過程の詳細なレポートが得られる
# これには、内部のFXグラフの表現も含まれる
torch._dynamo.explain(model, dummy_input)
利点
- FXグラフだけでなく、生成された低レベルコード(Triton/CUDAなど)へのポインタも提供する場合がある。
- 最適化後の内部グラフや、なぜ特定のコードがコンパイルできないのかといった洞察が得られる。
torch.compile
で発生する問題(グラフブレイク、ガード失敗など)のデバッグに特化しており、非常に詳細な情報を提供する。
欠点
- 出力が非常に詳細であるため、慣れるまでに時間がかかる。
- PyTorch 2.0 以降の
torch.compile
と密接に結びついているため、純粋なFXのデバッグにはややオーバーキルかもしれない。
GraphModule.code の参照
GraphModule
は、内部で生成されたPythonコードを code
プロパティとして保持しています。これは print_readable()
と似た形式ですが、文字列として直接アクセスできます。
import torch
import torch.nn as nn
import torch.fx
class AnotherModule(nn.Module):
def forward(self, x, y):
a = x * 2
b = y + 1
return a - b
model = AnotherModule()
dummy_x = torch.randn(5)
dummy_y = torch.randn(5)
graph_module = torch.fx.symbolic_trace(model, concrete_args={'x': dummy_x, 'y': dummy_y})
print("--- Accessing GraphModule.code ---")
print(graph_module.code)
利点
- ファイルへの保存や、文字列操作による解析が容易。
print_readable()
とほぼ同じ出力内容を、文字列としてプログラム的に取得できる。
欠点
print_readable()
と比較して、フォーマットが固定されており、直接的な整形オプションがない。
- 生成されたコードを文字列として取得し、さらに加工したい場合
graph_module.code
プロパティを使用します。 - torch.compile でのパフォーマンス問題やグラフブレイクのデバッグ
torch._dynamo.explain()
が最も包括的で推奨されるツールです。 - モデルの構造を視覚的に理解したい場合
torchviz
やtorchview
(FXグラフに近い描画)、あるいはTensorBoard
(Autogradグラフ) が役立ちます。特に複雑な分岐や結合を持つモデルには可視化が有効です。 - より詳細なグラフ構造のプログラム的な分析
graph_module.graph.nodes
を直接イテレートする方法が適しています。 - 簡単なモデルの動作確認やFXの学習
torch.fx.GraphModule.print_readable()
が最も手軽で分かりやすいです。