PyTorch FXデバッグ術:create_args_for_root()関連エラーの対処法
PyTorch の torch.fx
は、nn.Module
のインスタンスを変換するためのツールキットであり、主に以下の3つの主要コンポーネントで構成されています。
- Symbolic Tracer (シンボリックトレーサー): Pythonコードの「シンボリック実行」を行います。実際の値ではなく「Proxy」と呼ばれる偽の値をコードに投入し、Proxyに対する操作を記録します。
- Intermediate Representation (中間表現): シンボリックトレース中に記録された操作を格納するためのコンテナです。Graphと呼ばれるデータ構造で、関数入力、関数呼び出し、メソッド呼び出し、
nn.Module
インスタンスへの呼び出し、戻り値を表すNodeのリストで構成されます。 - Python Code Generation (Pythonコード生成): Graphから有効なPythonコードを生成します。
torch.fx.Tracer.create_args_for_root()
は、このシンボリックトレースの過程で、トレース対象となる nn.Module
の forward
メソッドに渡す「引数」を作成する役割を担います。
具体的には、以下の処理を行います。
Tracer
はこれらの Proxy 引数を用いてforward
関数を呼び出し、Proxyがプログラムを流れるにつれて、それらが触れたすべての操作(torch
関数の呼び出し、メソッド呼び出し、演算子)をFX GraphにNodeとして記録していきます。- これらの Proxy オブジェクトは、実際のテンソルではなく、トレース中に実行される操作を記録するための「代理」として機能します。
forward
メソッドの引数に対応する Proxy オブジェクトを作成します。
要するに、create_args_for_root()
は、トレースを開始するために必要なダミーの入力(Proxyオブジェクト)を生成し、nn.Module
の forward
メソッドがそれらを使ってどのように計算を進めるかを fx.Graph
として記録できるようにする、初期設定のような役割を果たすメソッドです。
TraceError: symbolically traced variables cannot be used as inputs to control flow (シンボリックトレースされた変数を制御フローの入力として使用できません)
エラーの原因
このエラーは、forward
メソッド内でデータ依存の制御フロー(if
文やfor
ループなど)が存在する場合によく発生します。torch.fx
は、コードのシンボリックな実行を行い、計算グラフを構築します。しかし、Tracer.create_args_for_root()
が生成するProxyオブジェクトは、実際の値ではなく、その値を介した操作を記録するためのものです。そのため、Proxyオブジェクトが条件分岐やループのイテレーション回数を決定するようなデータ依存の制御フローに使われると、トレーサーは具体的な値を評価できず、グラフを構築できなくなります。
トラブルシューティング
- カスタムのトレーサーを実装する
- 複雑なケースでは、
torch.fx.Tracer
をサブクラス化し、create_args_for_root()
やis_leaf_module()
などのメソッドをオーバーライドして、特定のモジュールや関数のトレース方法をカスタマイズする必要があります。これにより、トレースが難しい部分をスキップしたり、特定の方法で処理したりできます。
- 複雑なケースでは、
- 型ヒントの活用
- 特にリストやタプルなどの集合型を入力として受け取る場合、
forward
メソッドの引数に明示的な型ヒント(例:x: List[torch.Tensor]
)を追加することで、トレーサーが引数の構造を理解しやすくなる場合があります。
- 特にリストやタプルなどの集合型を入力として受け取る場合、
- データ依存の制御フローを避ける
- 可能な限り、モデルの構造が入力データに依存しないように設計します。
- 例えば、バッチサイズやシーケンス長が可変の場合でも、
torch.jit.trace
やtorch.compile
の動的シェイプ機能を使うか、固定のダミー入力でトレースを試みます。 - どうしてもデータ依存の制御フローが必要な場合は、その部分をトレースの範囲外にするか、
torch.fx.wrap()
やtorch.fx.wrap_method()
を使用して、その関数を「リーフノード」(内部をトレースせず、単一の操作として記録する)として扱うことを検討します。
TypeError: cat() received an invalid combination of arguments - got (Proxy, int), but expected one of: ... (引数の無効な組み合わせ)
トラブルシューティング
- fx.wrap() の利用
- もし、特定の関数(例:
torch.cat
)の呼び出し方がFXのデフォルトのトレーサーでは正しく扱えない場合、その関数をfx.wrap()
でラップすることで、その関数呼び出しが単一のノードとしてグラフに記録され、内部の引数の問題を回避できることがあります。ただし、これはその関数の最適化機会を失う可能性があります。
- もし、特定の関数(例:
- 明示的な入力構造
torch.fx.symbolic_trace()
には、args
引数を使って具体的なダミー入力(テンソルのリストやタプルなど)を渡すことができます。create_args_for_root()
が自動で生成する引数が不適切であれば、ここで明示的に正しい構造のダミー入力を与えることで問題を回避できます。
import torch import torch.nn as nn import torch.fx class MyCatModule(nn.Module): def forward(self, xs): # xs is expected to be a list of tensors return torch.cat(xs, dim=0) model = MyCatModule() # 正しい引数構造を渡す dummy_input = [torch.randn(2, 3), torch.randn(2, 3)] traced_model = torch.fx.symbolic_trace(model, concrete_args={'xs': dummy_input}) print(traced_model.graph)
AttributeError: 'Proxy' object has no attribute 'xxx' (Proxyオブジェクトに属性'xxx'がありません)
エラーの原因
このエラーは、モデル内でProxyオブジェクトが、実際のテンソルが持つべきではない属性(例: .item()
, .numpy()
, .shape
を直接使ってPythonの整数やリストを生成するなど)にアクセスしようとした場合に発生します。Proxyオブジェクトは、トレースのために必要な情報を保持しているだけで、実際のテンソルと同じすべての操作をサポートするわけではありません。
トラブルシューティング
- torch.fx.wrap() または torch.fx.map_arg() の使用
- もし、どうしてもテンソルの値をPythonのロジックに利用する必要がある場合、その部分を
fx.wrap()
でラップして「リーフノード」として扱うか、より高度なシナリオではtorch.fx.map_arg()
を使用してProxyを実際の値に変換してから操作し、その後再びProxyに戻すなどのカスタムロジックを実装する必要があります。ただし、これはFXの目的(グラフ最適化)に反する可能性があり、注意が必要です。
- もし、どうしてもテンソルの値をPythonのロジックに利用する必要がある場合、その部分を
- テンソル操作のみを使用する
- モデルの
forward
メソッド内では、PyTorchのテンソル操作(torch.Tensor
メソッドやtorch
関数)のみを使用するようにします。 - テンソルの値を直接取得してPythonの標準的な制御フローやデータ構造に使用することは避けてください。
- モデルの
動的なシェイプ、型、または値の変化
エラーの原因
create_args_for_root()
は、トレース開始時に固定されたダミー入力(Proxy)を生成します。もしモデルが、入力のシェイプ、データ型、または値によって内部の計算グラフが変化するような動的な振る舞いをする場合、FXはこれを適切にトレースできません。たとえば、入力のシェイプに基づいてレイヤーの構造を変えたり、入力の値によって異なるパスを実行したりするモデルです。
- concrete_args を使用して複数のパスをトレース
- 特定の動的な挙動が、少数の明確な入力シェイプの組み合わせに限定される場合、それぞれの入力シェイプに対して個別にトレースを行い、複数のGraphModuleを生成する方法も考えられます。
- torch.compileの検討
- PyTorch 2.0以降では、
torch.compile
がFXを内部的に利用しており、動的なシェイプや一部の制御フローに対してより堅牢なトレース(グラフ分割など)を提供します。FXの直接利用が難しい場合は、torch.compile
の使用を検討するのが良いでしょう。
- PyTorch 2.0以降では、
- 静的なグラフ設計
- FXによるトレースは、静的な計算グラフを想定しています。モデルが動的な振る舞いをする場合、FXの適用は困難です。可能な限り、入力のシェイプや値に依存しない静的な計算パスを持つようにモデルを設計します。
- torch.compile のトラブルシューティングガイドも参照する
torch.compile
はFXをベースにしているため、torch.compile
のトラブルシューティングガイド(特に「Graph breaks」のセクション)もFXのデバッグに役立つ情報を含んでいます。 - FX Graphを検査する
エラーが発生した場合、traced_model.graph.print_tabular()
やprint(traced_model.graph)
を使用して、生成されたグラフの構造を確認します。どこでトレースが失敗しているのか、どのようなノードが期待通りに生成されていないのかを把握するのに役立ちます。 - 最小限の再現コードを作成する
問題が発生した場合は、元のモデルから問題を再現できる最小限のコードを切り出すことが重要です。これにより、デバッグが容易になります。 - PyTorchのバージョンを確認する
torch.fx
はPyTorch 1.8で導入され、以降も活発に開発が進められています。古いバージョンを使用している場合、最新の機能やバグ修正が適用されていない可能性があります。
ここでは、create_args_for_root()
がどのように機能し、どのような文脈で重要になるのかを理解するためのいくつかの例と、関連する概念について説明します。
symbolic_trace の内部での利用 (最も一般的)
最も一般的なケースは、torch.fx.symbolic_trace()
を呼び出すことで、内部的に Tracer
がインスタンス化され、その中で create_args_for_root()
が呼び出されるパターンです。
import torch
import torch.nn as nn
import torch.fx
class SimpleModel(nn.Module):
def forward(self, x, y):
a = x + y
b = a * 2
return b - x
# モデルのインスタンス化
model = SimpleModel()
# symbolic_trace を実行
# この呼び出しの内部で、Tracer().trace(model, concrete_args=None) が実行され、
# さらにその中で create_args_for_root() が呼び出され、x と y に対応する Proxy オブジェクトが生成される。
traced_model = torch.fx.symbolic_trace(model)
print("--- Traced Graph ---")
print(traced_model.graph)
# 生成されたグラフを可視化(graphvizがインストールされている場合)
# traced_model.graph.print_tabular()
# traced_model.graph.to_dot().render("simple_model_graph", view=True)
# 実行例
dummy_x = torch.randn(3, 4)
dummy_y = torch.randn(3, 4)
output_traced = traced_model(dummy_x, dummy_y)
output_original = model(dummy_x, dummy_y)
print("\n--- Output Comparison ---")
print(f"Original output: {output_original}")
print(f"Traced output: {output_traced}")
assert torch.allclose(output_original, output_traced)
print("Outputs match!")
解説
この例では、torch.fx.symbolic_trace(model)
を呼び出すだけで、FXが自動的に SimpleModel
の forward
メソッドのシグネチャを検査し、それに対応するProxyオブジェクト(ダミーの入力)を生成します。この生成を担当しているのが、内部で呼び出される Tracer.create_args_for_root()
です。ユーザーは明示的に引数を与えていませんが、FXはデフォルトのルールに基づいてProxyを生成します。
concrete_args を使用して引数を明示的に指定する
create_args_for_root()
の動作をユーザーが部分的に制御する一般的な方法は、symbolic_trace
の concrete_args
引数を使用することです。これにより、特定の引数をProxyとしてではなく、具体的なPythonオブジェクト(テンソルや数値など)としてトレースに渡すことができます。
import torch
import torch.nn as nn
import torch.fx
class DynamicShapeModel(nn.Module):
def forward(self, x, factor: int):
# factor が具体的な int であると、この if 文はトレース時に評価される
if factor > 1:
return x * factor
else:
return x / 2
model = DynamicShapeModel()
print("--- Tracing with concrete_args for 'factor' ---")
# 'factor' を具体的な整数としてトレースに渡す
# create_args_for_root() は 'x' には Proxy を、'factor' には 3 を生成する
traced_model_concrete = torch.fx.symbolic_trace(model, concrete_args={'factor': 3})
print("Graph when factor is 3:")
print(traced_model_concrete.graph)
# expected: x * 3
# 別の factor でトレース
traced_model_concrete_alt = torch.fx.symbolic_trace(model, concrete_args={'factor': 0})
print("\nGraph when factor is 0:")
print(traced_model_concrete_alt.graph)
# expected: x / 2
# concrete_args を指定しない場合
print("\n--- Tracing without concrete_args (may fail or produce unexpected graph) ---")
try:
# この場合、factor も Proxy になり、factor > 1 の条件を評価できないため TraceError が発生する可能性が高い
traced_model_no_concrete = torch.fx.symbolic_trace(model)
print("Graph without concrete_args (if successful):")
print(traced_model_no_concrete.graph)
except torch.fx.TraceError as e:
print(f"TraceError occurred as expected: {e}")
解説
concrete_args
を使用すると、create_args_for_root()
が特定の引数に対してProxyではなく、指定された具体的な値を生成するように指示できます。これにより、その引数がデータ依存の制御フロー(例: if factor > 1
)に使われている場合でも、トレース中にその条件を評価し、特定のパスを辿る静的なグラフを生成することができます。concrete_args
の値が異なれば、生成されるグラフも異なる場合があります。
非常に特殊なケースで、Tracer
のデフォルトの引数生成ロジックが不十分な場合、Tracer
をサブクラス化し、create_args_for_root()
メソッドをオーバーライドすることができます。これは高度なユースケースであり、FXの内部動作を深く理解している必要があります。
例えば、forward
メソッドが特定のカスタムオブジェクトや、FXがデフォルトでProxyを生成できない複雑な構造を持つ引数を期待する場合に有用かもしれません。
import torch
import torch.nn as nn
import torch.fx
from torch.fx.proxy import Proxy
# カスタムのデータ構造
class CustomArgs:
def __init__(self, value):
self.value = value
class CustomTracer(torch.fx.Tracer):
def create_args_for_root(self, root_fn, script_fn_root) -> tuple:
# root_fn はトレース対象のモジュール/関数の forward メソッド
# デフォルトの create_args_for_root のロジックを呼び出す
args, kwargs = super().create_args_for_root(root_fn, script_fn_root)
# ここで、生成された引数 (Proxyなど) を検査し、必要に応じて変更する
# 例えば、特定の型の引数をカスタムの Proxy に置き換えるなど
new_args = []
for arg in args:
if isinstance(arg, Proxy):
# デフォルトの Proxy をそのまま使うか、カスタムロジックを適用
new_args.append(arg)
# 例: もし CustomArgs 型の引数が forward にあれば、
# それを特別な Proxy でラップするなど
# if isinstance(arg, CustomArgs):
# new_args.append(CustomProxy(arg.value)) # カスタムProxyの例
else:
new_args.append(arg)
# ここでは単純にデフォルトの args をそのまま返す
# 実際には、特定の引数に対して別の Proxy タイプを生成したり、
# 非テンソル引数にカスタムの処理を加えたりする
print(f"CustomTracer: created args: {args}, kwargs: {kwargs}")
return args, kwargs
class ModelWithCustomArgs(nn.Module):
def forward(self, x: torch.Tensor, config: int):
# config は通常の int なので、デフォルトで concrete_arg になる
return x * config
model = ModelWithCustomArgs()
# カスタムTracerを使用してトレース
custom_tracer = CustomTracer()
# trace メソッドに concrete_args を渡すことも可能
traced_model_custom = custom_tracer.trace(model, concrete_args={'config': 5})
print("\n--- Traced Graph with CustomTracer ---")
print(traced_model_custom.graph)
dummy_x = torch.randn(2, 2)
output_traced = traced_model_custom(dummy_x, 5)
output_original = model(dummy_x, 5)
print("\n--- Output Comparison ---")
print(f"Original output: {output_original}")
print(f"Traced output: {output_traced}")
assert torch.allclose(output_original, output_traced)
print("Outputs match!")
解説
この例では、CustomTracer
を作成し、create_args_for_root
をオーバーライドしています。ただし、このオーバーライドは単に super().create_args_for_root()
を呼び出し、生成された引数を出力するだけです。実際のユースケースでは、ここで独自のロジックを追加し、例えば特定の条件に基づいて異なる種類のProxyを生成したり、引数のデフォルト値を変更したりすることが考えられます。
- これにより、
forward
メソッドがProxyオブジェクトを受け取り、その後の演算がFXグラフとして記録される準備が整う。 - 各引数に対して、トレースに必要なProxyオブジェクト(または
concrete_args
で指定された具体的な値)を生成する。 forward
メソッドのシグネチャ(引数の名前、型ヒント)を解析する。
しかし、「create_args_for_root()
に関連するプログラミングの代替方法」という文脈では、以下の点が考えられます。
create_args_for_root()
が生成する引数の内容を制御する代替方法- FXトレース以外の、PyTorchモデルを変換・最適化する代替方法
それぞれについて詳しく説明します。
create_args_for_root() が生成する引数の内容を制御する代替方法
create_args_for_root()
は、torch.fx.symbolic_trace()
の呼び出し時に裏側で実行されます。この自動生成される引数の内容を制御するための主要な代替方法は、以下の通りです。
a. torch.fx.symbolic_trace()
の concrete_args
引数を使用する
これが最も一般的で推奨される方法です。concrete_args
を使用すると、forward
メソッドの一部の引数をProxyオブジェクトではなく、実際のPythonオブジェクト(テンソル、整数、ブール値など)としてFXトレーサーに「見せる」ことができます。これにより、その引数に依存する制御フロー(if
文など)をトレース時に評価させ、特定のパスのみをグラフに含めることができます。
使用例
import torch
import torch.nn as nn
import torch.fx
class ConditionalModel(nn.Module):
def forward(self, x: torch.Tensor, flag: bool):
if flag:
return x * 2
else:
return x / 2
model = ConditionalModel()
# flag=True のケースでトレース
# create_args_for_root は x に Proxy を、flag に True を生成する
traced_model_true = torch.fx.symbolic_trace(model, concrete_args={'flag': True})
print("--- Graph for flag=True ---")
print(traced_model_true.graph)
# グラフには 'x * 2' の部分のみが含まれる
# flag=False のケースでトレース
# create_args_for_root は x に Proxy を、flag に False を生成する
traced_model_false = torch.fx.symbolic_trace(model, concrete_args={'flag': False})
print("\n--- Graph for flag=False ---")
print(traced_model_false.graph)
# グラフには 'x / 2' の部分のみが含まれる
# 実行確認
dummy_x = torch.randn(3, 3)
assert torch.allclose(traced_model_true(dummy_x, True), model(dummy_x, True))
assert torch.allclose(traced_model_false(dummy_x, False), model(dummy_x, False))
なぜこれが代替方法なのか?
create_args_for_root()
のデフォルトの挙動は、forward
メソッドの引数から自動的にProxyを生成しようとします。しかし、concrete_args
を使うことで、この自動生成のロジックに介入し、特定の引数に対してはProxyではない具体的な値を「注入」することができます。これにより、create_args_for_root()
が生成する引数のセットと、それによって形成されるグラフの形状を制御できます。
b. モデルの forward
メソッドのシグネチャを変更する
例
import torch
import torch.nn as nn
import torch.fx
class ModelWithDefaultArg(nn.Module):
def forward(self, x: torch.Tensor, use_bias: bool = True):
# use_bias にデフォルト値があるため、concrete_args なしでもトレース可能になることが多い
if use_bias:
return x + 1.0
else:
return x
model = ModelWithDefaultArg()
# デフォルト値 (True) でトレースされる
traced_model_default = torch.fx.symbolic_trace(model)
print("--- Graph with default arg ---")
print(traced_model_default.graph)
# グラフには 'x + 1.0' の部分が含まれる
なぜこれが代替方法なのか?
create_args_for_root()
は、forward
メソッドの引数リストを調べて引数を生成します。引数にデフォルト値がある場合、FXはしばしばそのデフォルト値を concrete_arg
のように扱い、制御フローを解決しようとします。これにより、明示的に concrete_args
を指定しなくても、期待するグラフを生成できる場合があります。
c. カスタム Tracer
を作成し、create_args_for_root()
をオーバーライドする (上級者向け)
これは最も柔軟な方法ですが、最も複雑でもあります。torch.fx.Tracer
をサブクラス化し、create_args_for_root()
メソッドを直接オーバーライドすることで、引数の生成ロジックを完全にカスタマイズできます。
ユースケース
- 引数の型ヒントだけでは不十分で、追加のヒューリスティックに基づいて引数を生成したい場合。
- 特定の引数に対して、カスタムのProxyオブジェクトを生成したい場合。
- デフォルトのトレーサーでは正しく扱えない、非常に複雑なカスタムデータ構造が
forward
メソッドの引数にある場合。
注意点
この方法はFXの内部に深く入り込むため、FXのProxyシステム、GraphNode、PythonのAST(抽象構文木)に関する知識が必要になることがあります。
torch.fx.Tracer.create_args_for_root()
が関わるのはFXトレースですが、PyTorchモデルの変換や最適化には、FX以外にも様々なアプローチがあります。これらも「代替方法」と考えることができます。
a. torch.jit.trace()
/ torch.jit.script()
(TorchScript)
PyTorchにおけるモデル変換の最も古い、そして今でも広く使われている方法の一つです。
torch.jit.script()
: Pythonソースコードを直接解析し、TorchScriptの静的型システムと互換性のあるグラフを構築します。制御フローも適切に処理できます。- 利点: データ依存の制御フローも処理可能。
- 欠点: Pythonの全機能をサポートするわけではなく、TorchScriptのサブセットに限定されます。デバッグが難しい場合があります。
torch.jit.trace()
: 具体的なダミー入力を用いて、モデルの実際の実行パスを記録し、TorchScript形式のグラフを生成します。FXがソースコードを解析するのに対し、trace
は実行時動作を記録します。- 利点: 非常にシンプルで、PyTorchのほとんどの操作をサポートします。
- 欠点: データ依存の制御フロー(
if
文やfor
ループのイテレーション回数が入力データに依存する場合など)を適切に扱えません。トレース時の入力パスしか記録されません。
関連性
torch.jit.trace()
はダミー入力が必須であり、その入力が create_args_for_root()
のような役割を果たします。
b. torch.compile()
(PyTorch 2.0以降)
PyTorch 2.0で導入された最先端のコンパイラで、FXをバックエンドとして使用しています。これは、モデルの高速化と最適化を自動で行うための最も推奨される方法です。
torch.compile()
は、FXが単独で処理できないような複雑な制御フローや動的なシェイプに対しても、グラフを「分割」(graph break)することで対応しようとします。torch.compile()
は、内部でFXトレーサー(や他のコンパイラ技術)を使用してモデルのグラフを抽出し、それをTritonなどの低レベルカーネルにコンパイルします。
関連性
torch.compile
はFXを内部的に使用するため、create_args_for_root()
が行う引数生成のステップも間接的に利用されます。torch.compile
はユーザーが直接FXの内部を操作することなく、FXの恩恵を受けるためのより高レベルなインターフェースを提供します。concrete_args
と同様の機能は、torch.compile
の dynamic=True
や fullgraph=False
といったオプションで間接的に制御されます。
c. ONNX (Open Neural Network Exchange) エクスポート
PyTorchモデルをONNX形式にエクスポートし、他のフレームワーク(TensorFlow, ONNX Runtimeなど)で推論を実行するための標準的な中間表現です。
- ONNXはFXとは異なるグラフ表現を持ち、特定のONNXオペレータに変換されます。
torch.onnx.export()
を使用します。これもモデルのforward
メソッドにダミー入力が必要です。
関連性
torch.onnx.export()
もモデルの変換に際してダミー入力が必須であり、その点で create_args_for_root()
が果たす役割(トレース開始のための入力準備)と概念的に似ています。
d. カスタムなグラフ変換ツール
ごく稀に、特定の最適化や変換のために、FXやTorchScriptに依存しない独自のツールやスクリプトを開発することもあります。例えば、モデルの層を直接操作したり、nn.Module
をイテレートして特定のパターンを探したりするなどです。これは非常に低レベルで、汎用性に欠けますが、特定のニッチな問題には有効な場合があります。
torch.fx.Tracer.create_args_for_root()
自体をプログラミングで代替するというよりは、「FXトレースの開始時に入力を準備し、それによってグラフの形状を制御する」 という目的においては、torch.fx.symbolic_trace()
の concrete_args
引数が最も直接的かつ一般的に使用される代替手段です。