【PyTorch FX】create_node()徹底解説!エラーと対処法も網羅
FXは、PyTorchモデルを操作可能なグラフ(有向非巡回グラフ:DAG)として表現し、そのグラフを変換・最適化・コード生成するためのツールキットです。このグラフは、モデルの入力、演算(関数呼び出し、モジュール呼び出し)、および出力などを表す「ノード」と、それらのノード間のデータフローを表す「エッジ」で構成されます。
torch.fx.Graph.create_node()
メソッドは、このグラフに新しい操作や値を表すノードを手動で追加するために使用されます。
create_node()
の主な引数と、それが何をするのか
create_node()
メソッドは、主に以下の引数を受け取ります。
-
op
(Operation Type):- このノードがどのような種類の操作を表すかを示します。FXグラフには、主に以下の5種類の操作があります。
'placeholder'
: 関数の入力(引数)を表すノード。'call_function'
: Pythonの自由関数(例:torch.add
,F.relu
など)の呼び出しを表すノード。'call_method'
: オブジェクトのメソッド呼び出し(例:x.view()
,x.mean()
など)を表すノード。'call_module'
:torch.nn.Module
のインスタンス(例:self.linear(x)
,self.conv1(x)
など)のforward
メソッド呼び出しを表すノード。'get_attr'
: モデルの属性(例:self.param
)を取得するノード。'output'
: グラフの最終的な出力を表すノード。
- このノードがどのような種類の操作を表すかを示します。FXグラフには、主に以下の5種類の操作があります。
-
target
(Target of Operation):op
の種類に応じて、その操作の対象となるものを指定します。'placeholder'
の場合:入力引数の名前(文字列)。'call_function'
の場合:呼び出される関数オブジェクト(例:torch.add
)。'call_method'
の場合:呼び出されるメソッドの名前(文字列、例:'view'
)。'call_module'
の場合:呼び出されるモジュールの完全修飾名(文字列、例:'linear_layer.0'
)。'get_attr'
の場合:取得する属性の完全修飾名(文字列、例:'param'
)。'output'
の場合:出力される値。
-
args
(Positional Arguments):- その操作に渡される位置引数(タプル)です。これらの引数は、通常、既にグラフ内に存在する他のノード(Nodeオブジェクト)への参照、または定数値になります。
-
kwargs
(Keyword Arguments):- その操作に渡されるキーワード引数(辞書)です。
args
と同様に、他のノードへの参照や定数値を含みます。
- その操作に渡されるキーワード引数(辞書)です。
-
name
(Node Name):- グラフ内で一意なノードの名前(文字列)。省略すると自動的に生成されます。
create_node()
の目的と利用シーン
create_node()
は、主に以下のような目的で使われます。
- コード生成: FXグラフは、最終的にPythonコードとして再生成されることがありますが、そのコードの元となる構造を
create_node()
で定義します。 - グラフの変換・編集: 既存のFXグラフをロードした後、そのグラフに新しい演算を追加したり、既存の演算を変更・削除したりする際に、手動で
create_node()
を呼び出して新しいノードを挿入します。これは、モデルの最適化(例: レイヤーマージ、不必要な演算の削除)や、カスタムな計算グラフの作成に利用されます。 - グラフの構築: PyTorchのモデルをトレース(
torch.fx.symbolic_trace
)する際に、内部的に各操作がこのcreate_node()
によってグラフに追加されていきます。
例えば、x + y
という操作をグラフに追加したい場合、概念的には以下のようにノードを作成します。
import torch
import torch.fx as fx
# グラフを初期化(通常はsymbolic_traceで取得するが、ここでは手動で作成の例)
graph = fx.Graph()
# 入力ノードを作成
x_node = graph.placeholder('x')
y_node = graph.placeholder('y')
# 足し算のノードを作成
# 'call_function' は関数呼び出しを示し、target は torch.add 関数
# args は引数として x_node と y_node を渡す
add_node = graph.create_node('call_function', torch.add, args=(x_node, y_node))
# 出力ノードを作成
output_node = graph.output(add_node)
# グラフを表示してみる
graph.print_tabular()
この例では、x
と y
というプレースホルダーノード、そしてそれらを加算するtorch.add
の関数呼び出しノードをcreate_node()
を使ってグラフに追加しています。
target と op の不整合
エラーの例: RuntimeError: "target" argument must be a function for op 'call_function'
のようなエラー
原因: op
引数で指定した操作の種類 ('call_function'
, 'call_method'
, 'call_module'
など) と、target
引数で渡す値の型が一致しない場合に発生します。
'call_module'
の場合、target
はモジュールのパス (例:'linear1'
,'sub_module.conv'
) の文字列である必要があります。'call_method'
の場合、target
は文字列 (例:'view'
,'sum'
) で、そのメソッド名を示します。'call_function'
の場合、target
は callable な関数オブジェクト (例:torch.add
,F.relu
) である必要があります。
トラブルシューティング:
torch.fx.Node
のドキュメントを参照し、各op
タイプに対応するtarget
の要件を再確認します。op
とtarget
の組み合わせが正しいかを確認します。特に、関数呼び出しなのに文字列を渡してしまったり、モジュール呼び出しなのに関数オブジェクトを渡してしまったりするミスが多いです。
args / kwargs の入力ノードの不適切さ
エラーの例:
- 特に、既存のグラフから値を受け取る場合は、その値がノードオブジェクトとして渡されているかを確認します。
args
やkwargs
に渡す値が、適切なfx.Node
オブジェクトであるか、または定数値であるかを確認します。RuntimeError: "Node" object expected, but received "Tensor"
(グラフ内のノードを参照すべき場所で生のテンソルを渡してしまった場合) 原因:create_node()
のargs
やkwargs
には、通常、既にグラフ内に存在する他のノード (fx.Node
オブジェクト) を渡す必要があります。数値や文字列などの定数も渡せますが、データの流れを表す場合はノード間の接続が必要です。 トラブルシューティング:TypeError: 'int' object is not iterable
(引数がタプルやリストであるべきなのに整数を渡してしまった場合)
モジュールノード (call_module) の解決不足
エラーの例: AttributeError: 'MyModule' object has no attribute '_param_constant0'
のようなエラー、または RuntimeError: Attempted to insert a call_module Node with a target that does not exist on the owning GraphModule.
原因: op='call_module'
でノードを作成する場合、そのtarget
は、グラフが属する GraphModule
インスタンス内に実際に存在するサブモジュールへのパス(文字列)である必要があります。手動でグラフを構築している場合、このサブモジュールがGraphModule
にadd_module
されていないとエラーになります。
トラブルシューティング:
symbolic_trace
を使用せずに手動でグラフを構築している場合、サブモジュールの名前付け規則 (例:self.sub_module_name
) に従っているかを確認します。GraphModule
を作成する際に、必要なサブモジュールを適切にadd_module
しているかを確認します。call_module
で参照するパスが、GraphModule
のnamed_modules()
でリストアップされるパスと正確に一致するかを確認します。
グラフの再コンパイル忘れ (GraphModule.recompile())
エラーの例: グラフを操作したのに、モデルの振る舞いが変わらない、または不正な出力が得られる。
原因: torch.fx.Graph
オブジェクトを直接操作してノードを追加・削除・変更しても、その変更はすぐに GraphModule
の forward
メソッドに反映されるわけではありません。GraphModule
は、内部的にグラフに基づいてforward
メソッドを生成しています。
トラブルシューティング:
- グラフを変更した後、必ず
GraphModule
インスタンスのrecompile()
メソッドを呼び出してください。これにより、変更されたグラフに基づいてforward
メソッドが再生成され、新しいグラフ構造が有効になります。
symbolic_trace と手動作成の混同
エラーの例: TraceError: symbolically traced variables cannot be used as inputs to control flow
や、期待通りのトレースができない。
原因: torch.fx.symbolic_trace
を使用して既存のPyTorchモデルからグラフを生成する場合と、Graph()
を直接インスタンス化して手動でノードを追加する場合では、FXが内部的に処理するメカニズムが異なります。特に、symbolic_trace
は、Pythonの制御フロー(if/else
、for
ループなど)を完全にキャプチャできない場合があります。手動でノードを作成する際は、このような制約を考慮する必要があります。
トラブルシューティング:
- 手動でグラフを構築する際は、それぞれのノードが具体的な演算や値に直接対応するように構成し、Pythonの抽象的な制御フローに依存しないようにします。
- 複雑な制御フローを持つモデルをFXで処理したい場合は、
torch.compile
のような上位のAPIの使用を検討します。torch.compile
は内部的にFXグラフを利用しますが、より高度な制御フロー処理をサポートします。
エラーの例: グラフ内のノードを削除または置き換えた後、そのノードを参照していた他のノードが壊れる。
原因: グラフ内のノードは相互に参照しあっています。あるノードを削除したり、その出力が別のノードに置き換えられたりした場合、そのノードのusers
(そのノードの出力を利用しているノード)が新しい出力を参照するように更新されないと、グラフが不整合な状態になります。
トラブルシューティング:
Graph.eliminate_dead_code()
を呼び出すことで、参照されなくなった(デッドな)ノードをクリーンアップできます。- ノードを削除または置き換える際には、
Node.replace_all_uses_with(new_node)
やNode.replace_input_with(old_input, new_input)
といったメソッドを適切に使用して、グラフの整合性を維持するようにしてください。
ゼロからグラフを構築する例
この例では、x + y
というシンプルな計算を表現するグラフをゼロから構築します。
import torch
import torch.fx as fx
import operator # torch.add の代わりに Python の operator.add を使うこともできます
# 1. 新しい空のGraphオブジェクトを作成します
graph = fx.Graph()
# 2. プレースホルダーノード (入力) を作成します
# 'placeholder' は、グラフの入力引数を表します。targetはその引数の名前です。
x_node = graph.placeholder('x')
y_node = graph.placeholder('y')
# 3. 'call_function' ノードを作成します
# これは Python の関数呼び出しを表します。
# target は呼び出される関数オブジェクト(ここでは torch.add)です。
# args は位置引数のタプルで、既存のノード(x_node, y_node)を渡します。
add_node = graph.create_node('call_function', torch.add, args=(x_node, y_node))
# 4. 出力ノードを作成します
# 'output' はグラフの最終的な出力を表します。
# target は出力される値のノードです。
output_node = graph.output(add_node)
# 5. GraphModuleを作成し、グラフを実行可能にします
# GraphModuleはnn.Moduleのサブクラスであり、このGraphをforwardメソッドとして持ちます。
# 最初の引数はGraphModuleが属するモジュール(ここではダミーのnn.Module)ですが、
# 通常はsymbolic_traceが生成します。手動作成ではダミーでOKです。
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
# このモジュールには実際のサブモジュールは不要ですが、GraphModuleのコンストラクタは
# 最初の引数として親モジュールを期待するため、ダミーとして渡します。
# または、GraphModule(None, graph) とすることもできます。
pass
gm = fx.GraphModule(MyModule(), graph)
# 6. グラフの構造を表示します
# tabular形式でノードの情報を確認できます。
print("--- 構築したグラフ ---")
gm.graph.print_tabular()
# 7. 実行して結果を確認します
input_x = torch.tensor(10.0)
input_y = torch.tensor(20.0)
result = gm(input_x, input_y)
print(f"\n実行結果: {result}") # 期待される出力: 30.0
既存のグラフにノードを追加・挿入する例
既存のPyTorchモデルをトレースして得られたFXグラフを編集する例です。ここでは、線形層の後にReLUを追加してみます。
import torch
import torch.nn as nn
import torch.fx as fx
# 元のモデルを定義
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 1)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
model = SimpleModel()
# symbolic_traceを使ってモデルからグラフを取得
traced_model = fx.symbolic_trace(model)
graph = traced_model.graph
print("--- 元のグラフ ---")
graph.print_tabular()
# 新しいReLUノードを挿入する場所を見つけます
# linear1 の出力を探し、その後に ReLU を挿入したい
for node in graph.nodes:
if node.name == 'linear1': # linear1 ノードを見つける
linear1_output_node = node
break
# linear1 の直後に ReLU を挿入するための準備
# graph.inserting_after() コンテキストマネージャを使うと、指定したノードの直後にノードが挿入されます
with graph.inserting_after(linear1_output_node):
# 'call_function' ノードとして torch.relu を作成
# args には linear1_output_node の出力を渡します
relu_node = graph.create_node('call_function', torch.relu, args=(linear1_output_node,))
# linear2 ノードの入力を更新します
# 元々 linear2 は linear1_output_node を入力としていましたが、
# これを relu_node の出力に変更します。
# find_users は、このノードの出力を利用しているノードを見つけるのに便利です。
for user_node in linear1_output_node.users:
# linear2 が linear1 の出力を利用している場合を想定
# 通常は graph.nodes を順に辿って見つける方が確実
if user_node.name == 'linear2':
# linear2 の入力リストをイテレートし、linear1_output_node を relu_node に置き換えます
# Node.replace_input_with() を使うのが最も安全な方法です
user_node.replace_input_with(linear1_output_node, relu_node)
# グラフの変更をGraphModuleに反映させるためにrecompileします
traced_model.recompile()
print("\n--- ReLUを追加した後のグラフ ---")
traced_model.graph.print_tabular()
# 変更後のモデルを実行して確認
input_tensor = torch.randn(1, 10)
output_before_relu = model(input_tensor) # 元のモデルの出力
output_after_relu = traced_model(input_tensor) # 変更後のモデルの出力
print(f"\n元のモデルの出力 (linear1 の後に ReLU なし): {output_before_relu}")
print(f"変更後のモデルの出力 (linear1 の後に ReLU あり): {output_after_relu}")
# ReLUの効果を簡易的に確認 (値が負にならないこと)
print(f"元のモデルの linear1 出力後 (生のデータ): {model.linear1(input_tensor)}")
print(f"変更後のモデルの linear1 出力後 (ReLU適用後): {traced_model.graph.find_node('relu').args[0]}") # この行はデバッグ用。直接Graphを評価する方法ではない。
解説:
- 最後に
traced_model.recompile()
を呼び出すことで、グラフの変更がGraphModule
のforward
メソッドに反映されます。 user_node.replace_input_with(old_input, new_input)
を使って、linear2
ノードがlinear1
の出力ではなく、新しく作成したrelu_node
の出力を参照するように変更します。これにより、データフローが正しく更新されます。graph.create_node('call_function', torch.relu, args=(linear1_output_node,))
で新しいReLUノードを作成します。引数にはlinear1_output_node
を指定し、ReLUがその出力を受け取るようにします。graph.inserting_after(linear1_output_node)
を使用して、linear1_output_node
の直後に新しいノードを挿入するコンテキストを設定します。- グラフ内の特定のノード (
linear1
) を見つけます。 symbolic_trace(model)
で既存のモデルからグラフを生成します。
カスタムモジュールをグラフに追加する例です。
import torch
import torch.nn as nn
import torch.fx as fx
class CustomActivation(nn.Module):
def forward(self, x):
return torch.sigmoid(x) * 2 # カスタムのアクティベーション関数
class MyNetwork(nn.Module):
def __init__(self):
super().__init__()
self.linear_in = nn.Linear(10, 20)
self.custom_act = CustomActivation() # カスタムモジュール
self.linear_out = nn.Linear(20, 1)
def forward(self, x):
x = self.linear_in(x)
x = self.custom_act(x) # ここがcall_moduleノードになる
x = self.linear_out(x)
return x
model = MyNetwork()
traced_model = fx.symbolic_trace(model)
print("--- モジュールを含む元のグラフ ---")
traced_model.graph.print_tabular()
# ここでは既存のグラフからモジュールのノードを直接作成するのではなく、
# 'call_module' の概念を理解するための補足的な例とします。
# 実際には symbolic_trace が自動的に生成します。
# もし手動で 'custom_act' のようなノードを作成したい場合 (通常はトレースされる)
# まず、グラフに追加したいモジュールが GraphModule の属性として存在することを確認します。
# traced_model に 'custom_act' という属性があることを確認
assert hasattr(traced_model, 'custom_act')
# linear_in の出力を取得(これは既存のグラフから取得)
linear_in_node = None
for node in traced_model.graph.nodes:
if node.name == 'linear_in':
linear_in_node = node
break
if linear_in_node:
# with traced_model.graph.inserting_after(linear_in_node): # 挿入ポイントを設定
# # 'call_module' ノードを作成
# # target は GraphModule 内のサブモジュールへのパス(文字列)
# # args はモジュールに渡される入力ノード
# # このノードは実際には symbolic_trace によって生成されるものと同じ
# custom_act_node = traced_model.graph.create_node(
# 'call_module', 'custom_act', args=(linear_in_node,)
# )
# print("\n--- 手動で作成されたカスタムアクティベーションノード(理論上) ---")
# print(custom_act_node)
pass
else:
print("linear_in_node が見つかりませんでした。")
# この例は、`create_node()` を使って手動で GraphModule のサブモジュールへの
# 呼び出しノードを定義する方法を示しています。
# ただし、ほとんどの場合、`symbolic_trace` がこれらのノードを自動的に生成するため、
# 開発者が直接 `create_node()` で `call_module` ノードを生成する必要はあまりありません。
# 主な利用シーンは、グラフ変換ツールなどで、既存のノードを別のモジュール呼び出しに置き換える場合などです。
解説:
- この例では、
symbolic_trace
が自動的にcall_module
ノードを作成することを示しています。手動で作成する場合の概念的なコードはコメントアウトされていますが、実際に手動で追加することは稀です。 call_module
を使用する場合、target
はGraphModule
インスタンス内のサブモジュールの名前(文字列)である必要があります。このサブモジュールは、GraphModule
が構築された際に、元のモデルからコピーされてくるものです。
torch.fx.symbolic_trace() を使う (最も一般的)
説明:
これは、既存の torch.nn.Module
インスタンスからFXグラフを生成する最も一般的で推奨される方法です。symbolic_trace()
は、モデルの forward
メソッドをシンボリックに実行し、その過程で発生するすべてのPyTorch操作やモジュール呼び出しを自動的にFXグラフのノードとしてキャプチャします。
create_node()
との関連:
symbolic_trace()
の内部では、モデルの各操作(例: torch.add
の呼び出し、nn.Linear
の呼び出しなど)が検出されるたびに、対応する create_node()
の呼び出しが自動的に行われ、グラフにノードが追加されます。開発者が明示的にノードタイプ ('call_function'
, 'call_module'
など) や引数を指定する必要はありません。
利点:
- PyTorchの慣習に沿う: 通常のPyTorchモデル定義から直接FXグラフを作成できます。
- 網羅性: モデルの大部分の操作(テンソル操作、モジュール呼び出し)を自動的にキャプチャします。
- 容易さ: 既存のモデルをFXグラフに変換する最も簡単な方法です。
欠点:
- 部分的なグラフの構築が難しい: 特定の操作だけを含むグラフをゼロから構築するのには向きません。
- 制御フローの制限: Pythonの動的な制御フロー(例: テンソルの値に依存する
if/else
)は、トレース時に静的に決定されるため、完全にはキャプチャできません。
コード例:
import torch
import torch.nn as nn
import torch.fx as fx
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
return x
model = MyModel()
traced_model = fx.symbolic_trace(model)
print("--- symbolic_trace で生成されたグラフ ---")
traced_model.graph.print_tabular()
# このtabular出力を見ると、'call_module' や 'call_function' ノードが自動的に生成されていることがわかります。
Node.replace_all_uses_with() や Node.replace_input_with() などのノード操作メソッド
説明:
create_node()
が「ノードを追加する」ためのAPIであるのに対し、FXグラフ内の既存のノードを「変更する」ための高レベルなメソッド群があります。これらは、グラフの変換や最適化を行う際に非常に役立ちます。
Node.replace_input_with(old_input, new_input)
:- このノードが受け取っている特定の入力
old_input
を、別の入力new_input
に置き換えます。
- このノードが受け取っている特定の入力
Node.replace_all_uses_with(new_node)
:- このノードの出力を利用しているすべてのノードに対し、参照先を
new_node
の出力に置き換えます。実質的に、このノードの機能をnew_node
に委譲し、古いノードはデッドコードになることが多いです。
- このノードの出力を利用しているすべてのノードに対し、参照先を
create_node()
との関連:
これらのメソッドは、既存のグラフを操作する際に、create_node()
で新しいノードを作成した後、その新しいノードを既存のグラフのデータフローに組み込むために組み合わせて使用されることが多いです。手動で args
や kwargs
を変更するよりも、より安全で意図が明確です。
利点:
- 高レベルな抽象化: ノードの入力を直接操作するよりも、意味的に分かりやすい操作です。
- 安全なグラフ変更: グラフの整合性を保ちながらノード間の接続を変更できます。
コード例:
前の例と同じ traced_model
を使用し、relu
の後に sigmoid
を追加し、pool
の入力を sigmoid
の出力に変更する例。
import torch
import torch.nn as nn
import torch.fx as fx
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, 3)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.pool(x)
return x
model = MyModel()
traced_model = fx.symbolic_trace(model)
print("--- 変更前のグラフ ---")
traced_model.graph.print_tabular()
# relu ノードを見つける
relu_node = None
for node in traced_model.graph.nodes:
if node.op == 'call_module' and node.target == 'relu':
relu_node = node
break
if relu_node:
# relu ノードの直後に新しい sigmoid ノードを挿入
with traced_model.graph.inserting_after(relu_node):
sigmoid_node = traced_model.graph.create_node(
'call_function', torch.sigmoid, args=(relu_node,)
)
# relu の出力を利用しているすべてのノード (この場合 pool) が、
# sigmoid_node の出力を利用するように置き換える
# この例では、relu_node.users を見て、直接 replace_input_with を呼び出す
# もしくは、relu_node.replace_all_uses_with(sigmoid_node) を使用することもできます。
for user in relu_node.users:
if user.target == 'pool': # 'pool' モジュール呼び出しがreluの出力を使っている
user.replace_input_with(relu_node, sigmoid_node)
break
# グラフの変更を反映させる
traced_model.recompile()
print("\n--- sigmoid を追加し、入力を置き換えた後のグラフ ---")
traced_model.graph.print_tabular()
Graph.inserting_before() / Graph.inserting_after() コンテキストマネージャ
説明:
これらのコンテキストマネージャは、新しいノードをグラフ内の特定の位置に挿入する際の利便性を提供します。これらのコンテキスト内で graph.create_node()
を呼び出すと、ノードは自動的に指定された位置に挿入されます。
create_node()
との関連:
create_node()
と組み合わせて使用することで、ノードの挿入位置を明示的に制御できます。これにより、手動でノードをリストに追加したり、順序を調整したりする手間が省けます。
利点:
- 順序の保証: グラフ内でのノードの評価順序を簡単に制御できます。
- ノード挿入の簡略化: コードがより読みやすく、ノード挿入の意図が明確になります。
コード例: 上記「既存のグラフにノードを追加・挿入する例」で既にこれを使用しています。
# relu_node の直後に新しい sigmoid ノードを挿入する部分
with traced_model.graph.inserting_after(relu_node):
sigmoid_node = traced_model.graph.create_node(
'call_function', torch.sigmoid, args=(relu_node,)
)
説明:
FXには、グラフを変換するための「パス」という概念があります。これは、GraphModule
を入力として受け取り、変換された新しい GraphModule
を返す関数やクラスです。これらは、より大規模なグラフ変換(例: 量子化、融合など)をカプセル化するために使われます。
create_node()
との関連:
FXパスの内部では、create_node()
が新しいノードを作成したり、既存のノードを操作したりするために使用されることがあります。しかし、パスのインターフェース自体は高レベルであり、パスのユーザーは通常 create_node()
を直接意識しません。
利点:
- 共通の最適化: PyTorchが提供する標準的な最適化(例: TorchDynamo/TorchInductorのバックエンド)で利用されます。
コード例 (概念的):
# 例えば、すべてのReLUをLeakyReLUに置き換えるパス
def relu_to_leaky_relu_pass(gm: fx.GraphModule) -> fx.GraphModule:
new_graph = fx.Graph()
env = {} # 古いノードと新しいノードのマッピング
for node in gm.graph.nodes:
if node.op == 'call_module' and isinstance(gm.get_submodule(node.target), nn.ReLU):
# 新しいLeakyReLUモジュールを作成して追加 (create_node ではないが、概念的に新しいモジュールノード)
# 実際には GraphModule に新しいモジュールを追加する必要がある
leaky_relu_module = nn.LeakyReLU()
new_target = node.target + '_leaky' # 新しいモジュール名
setattr(gm, new_target, leaky_relu_module) # GraphModuleに新しいサブモジュールを追加
# 新しい LeakyReLU ノードを作成 (create_node を内部で使用する)
new_node = new_graph.create_node('call_module', new_target, args=(env[node.args[0]],))
env[node] = new_node
else:
# 他のノードはそのままコピー (または新しいグラフに再作成)
# ここでも create_node が内部で使われることが多い
new_node = new_graph.create_node(node.op, node.target, node.args, node.kwargs, node.name)
env[node] = new_node
return fx.GraphModule(gm, new_graph)
# 利用例
# transformed_model = relu_to_leaky_relu_pass(traced_model)
# transformed_model.recompile()