call_moduleだけじゃない!PyTorch FXの多様なモデル変換手法
- Symbolic Tracer:
nn.Module
の実行をシンボリックにトレースし、モデルの演算をキャプチャします。 - Intermediate Representation (IR): トレースされた演算をグラフとして表現します。これは
torch.fx.Graph
というデータ構造で、ノードのリストから構成されます。 - Python Code Generation: IRからPythonコードを生成し、新しい
nn.Module
を作成します。
call_module()
の役割
torch.fx.Transformer
は、この中間表現であるGraph
を操作し、変換を行うための基底クラスです。Transformer
クラスには、グラフ内のさまざまな種類のノードに対応するメソッドが用意されており、それらをオーバーライドすることで、特定の演算に対するカスタムな変換ロジックを実装できます。
その中の1つがcall_module()
メソッドです。
call_module(self, target, args, kwargs)
は、torch.fx.Graph
内のノードが別のtorch.nn.Module
インスタンスを呼び出す操作を表す場合に呼び出されます。
具体的には、あるnn.Module
のforward
メソッド内で、別のサブモジュール(例: self.linear(x)
のような呼び出し)が実行されると、FXのトレース時にこの操作がcall_module
ノードとして記録されます。torch.fx.Transformer
を継承したカスタムの変換クラスでcall_module
をオーバーライドすることで、以下のようなことができます。
- モジュールの置き換え: あるモジュールを別のモジュール(例えば、最適化されたバージョンや量子化されたバージョン)に置き換えることができます。
- モジュールの削除: 特定のサブモジュールの呼び出しを完全に削除し、その入力をそのまま返すように変更することで、モジュールをグラフから取り除くことができます。
- 特定のサブモジュールの挙動を変更する: 例えば、ある特定の
nn.Module
が呼び出されたときに、そのモジュールの代わりに別のモジュールを呼び出すように変更したり、そのモジュールの入力や出力を加工したりできます。
例えば、モデル内のすべてのtorch.nn.Dropout
モジュールを削除したい場合、以下のようにtorch.fx.Transformer
を継承し、call_module
メソッドをオーバーライドします。
import torch
import torch.nn as nn
import torch.fx as fx
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
self.dropout = nn.Dropout(0.5)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.dropout(x) # このdropout呼び出しがcall_moduleノードになる
x = self.relu(x)
return x
class DropoutRemover(fx.Transformer):
def call_module(self, target, args, kwargs):
# targetは呼び出されているモジュールの名前(文字列、例: 'dropout')
# self.submodules[target]で実際のモジュールインスタンスにアクセスできる
if isinstance(self.submodules[target], nn.Dropout):
# Dropoutモジュールの場合、その入力をそのまま返すことで削除する
assert len(args) == 1 # 通常Dropoutは単一の入力
return args[0]
else:
# それ以外のモジュールの場合は、デフォルトのTransformerの挙動に従う
return super().call_module(target, args, kwargs)
# モデルを定義してトレース
model = MyModule()
traced_model = fx.symbolic_trace(model)
# Transformerを適用してDropoutを削除
transformed_model = DropoutRemover(traced_model).transform()
print("元のグラフ:")
traced_model.graph.print_tabular()
print("\n変換後のグラフ:")
transformed_model.graph.print_tabular()
この例では、DropoutRemover
クラスがcall_module
をオーバーライドし、呼び出されたモジュールがnn.Dropout
のインスタンスである場合に、そのモジュールの入力をそのまま返しています。これにより、結果として生成されるグラフからDropoutモジュールが削除され、推論の高速化やメモリ削減に貢献できます。
TraceError: symbolically traced variables cannot be used as inputs to control flow (制御フロー内のエラー)
原因
FXのSymbolic Tracerは、Pythonの動的な制御フロー(if
文、for
ループなど)を直接トレースすることは苦手です。特に、テンソルの値に依存するような制御フロー(例: if x.shape[0] > 0:
)があると、トレースが失敗したり、意図しない形で定数化されたりすることがあります。call_module()
の内部で、このようなトレースできない制御フローがあると、エラーが発生する可能性があります。
トラブルシューティング
- 動的なシェイプへの注意
入力テンソルのシェイプが動的に変化する場合、トレース時に与えるダミー入力のシェイプが実際の使用状況と一致しないと、問題が発生することがあります。FXはトレース時の入力シェイプに特化される傾向があるため、異なるシェイプの入力で実行するとエラーになることがあります。 - torch.fx.wrapの使用
特定の関数やモジュールの内部でトレースできない操作がある場合、torch.fx.wrap()
でその部分をラップすることで、FXがその関数/モジュールを「ブラックボックス」として扱い、内部のトレースをスキップさせることができます。ただし、これにより内部の最適化機会が失われる可能性があります。 - 制御フローの除去または書き換え
可能な限り、テンソルの値に依存する制御フローをモデルから除去するか、FXがトレースしやすい形(例:torch.where
などのテンソル演算)に書き換えます。
AttributeError: 'Graph' object has no attribute 'some_attribute' (間違ったノードや属性へのアクセス)
原因
call_module()
をオーバーライドする際、target
、args
、kwargs
といった引数を適切に扱わないと、期待しないエラーが発生します。特に、target
はモジュールの階層的な名前(例: 'linear'
や'encoder.layer.0.attn'
)であり、直接モジュールインスタンスではありません。
トラブルシューティング
- グラフの構造理解
traced_model.graph.print_tabular()
でグラフの構造を確認し、call_module
ノードがどのように表現されているかを理解することが重要です。どのモジュールがどのtarget
名で呼び出されているかを正確に把握します。 - argsとkwargsの確認
呼び出されるモジュールが期待する引数の数や型が、args
とkwargs
で提供されているものと一致しているか確認します。FXのノードのargs
はタプル、kwargs
は辞書です。 - self.submodules[target]の使用
call_module()
内で実際のモジュールインスタンスにアクセスしたい場合は、self.submodules[target]
を使用する必要があります。target
はあくまで文字列としての名前です。
RuntimeError: Expected all tensors to be on the same device, but found at least two devices (デバイスの不一致)
原因
call_module()
内で新しいテンソルを作成したり、既存のテンソルを操作したりする際に、デバイスの扱いを誤ると、このエラーが発生します。例えば、GPU上で動作しているモデルのモジュール呼び出しに対して、CPU上で新しいテンソルを作成して入力として渡そうとすると、エラーになります。
トラブルシューティング
- argsのデバイス
args
に含まれるテンソルは、多くの場合、呼び出されるモジュールと同じデバイスに既に存在します。変換でこれらを操作する際は、デバイスが変更されないように注意が必要です。 - デバイスの一貫性
新しく作成するテンソルや、変換によって挿入するモジュールが、元のモデルと同じデバイス上に存在するようにします。入力テンソルの.device
属性を確認し、それに応じて.to(device)
や.cuda()
を使用します。
変換後のモデルが期待通りに動作しない/結果が異なる
原因
これはエラーメッセージとして現れるわけではありませんが、最もよく遭遇する問題の一つです。call_module()
のオーバーライドが、モデルの動作に意図しない副作用をもたらしている可能性があります。
トラブルシューティング
- 元の挙動の呼び出し
変換のロジックが複雑になる場合、デフォルトのTransformer
の挙動を呼び出すsuper().call_module(target, args, kwargs)
を適切に使用し、変更が必要な部分のみをカスタム実装するようにします。 - 段階的な変換
一度に多くの変更を加えるのではなく、小さな変更を加えてはテストするというサイクルを繰り返すことで、問題の原因を特定しやすくなります。 - デバッグ出力
変換の前と後で、モデルの入力と出力の間にprint
文やデバッグ用のロギングを挿入し、具体的な値やシェイプの変化を追跡します。 - 副作用の管理
call_module()
内で、モジュールの状態(例:self.training
、バッファ、パラメータ)に影響を与えるような操作を行っていないか確認します。FXのトレースは、基本的に純粋関数的な動作を期待します。 - 入出力の一貫性
call_module()
が返す値の型、シェイプ、デバイスが、元のモジュールが返すものと完全に互換性があることを確認します。例えば、元のモジュールがTensor
を返していたのに、変換後にTuple[Tensor]
を返すと、後続の演算が失敗することがあります。
原因
これはcall_module()
自体に直接関係するエラーではありませんが、FXを使用する環境設定の問題で発生することがあります。PyTorch FXはPyTorch本体に含まれていますが、特定のバージョン要件がある場合があります。
- 環境のクリーンアップ
仮想環境(condaやvenv)を新しく作成し、必要なパッケージのみをインストールして試すことで、依存関係の競合が原因でないことを確認できます。 - PyTorchのバージョン確認
使用しているPyTorchのバージョンがFXをサポートしているか、および特定の機能(例: 新しいトレーサーオプション)が必要な場合は、そのバージョンを満たしているか確認します。通常、最新の安定版を使用していれば問題ありません。
例1: 特定のモジュールを別のモジュールに置き換える
この例では、モデル内のすべてのnn.ReLU
モジュールをnn.LeakyReLU
に置き換えます。
import torch
import torch.nn as nn
import torch.fx as fx
# 1. 元のモデルの定義
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 20)
self.relu = nn.ReLU() # これをLeakyReLUに置き換える
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
# 2. Transformerの定義: call_moduleをオーバーライド
class ReLUToLeakyReLUTransformer(fx.Transformer):
def call_module(self, target, args, kwargs):
# targetは呼び出されているモジュールの名前(文字列)
# self.submodules[target]で実際のモジュールインスタンスにアクセス
# 呼び出されたモジュールがnn.ReLUのインスタンスであるかチェック
if isinstance(self.submodules[target], nn.ReLU):
print(f"DEBUG: Replacing ReLU at target '{target}' with LeakyReLU.")
# 新しいLeakyReLUモジュールを作成し、サブモジュールとして登録
# NOTE: モジュールを登録しないと、変換後のグラフで再利用されない可能性があります。
# しかし、Transformerの transform() メソッドが最終的に新しいModuleを生成する際に、
# グラフ内のノードに紐づく新しいモジュールが適切にインスタンス化されます。
# ここでは単にその場で新しいインスタンスを返すことで、グラフがそのモジュールを
# 参照するようにします。
# 重要: Transformerの transform() メソッドが最終的なModuleを作成する際、
# ここで返されるモジュールインスタンスそのものではなく、
# それに対応する新しいノードがグラフに挿入されます。
# そのノードは新しいLeakyReLUインスタンスを指すことになります。
# 新しいLeakyReLUインスタンスを作成し、その呼び出しを返します
# argsは元のReLUの入力テンソルです。
return nn.LeakyReLU()(*args, **kwargs) # LeakyReLUを呼び出す
else:
# それ以外のモジュールの場合は、デフォルトのTransformerの挙動に従う
return super().call_module(target, args, kwargs)
# 3. モデルのトレースと変換
model = MyModel()
traced_model = fx.symbolic_trace(model)
print("--- 変換前のグラフ ---")
traced_model.graph.print_tabular()
# Transformerを適用
transformed_model = ReLUToLeakyReLUTransformer(traced_model).transform()
print("\n--- 変換後のグラフ ---")
transformed_model.graph.print_tabular()
# 4. 動作確認 (オプション)
input_data = torch.randn(1, 10)
output_original = model(input_data)
output_transformed = transformed_model(input_data)
# ReLUは負の値を0にするが、LeakyReLUは少しだけ負の値を残すので、結果は異なるはず
print(f"\nOriginal Model Output (first 5 values): {output_original[0, :5]}")
print(f"Transformed Model Output (first 5 values): {output_transformed[0, :5]}")
# グラフを見ると、'relu' ノードが 'leaky_relu' のようなノードに置き換わっていることがわかる
解説
MyModel
を定義し、nn.ReLU
を含めます。ReLUToLeakyReLUTransformer
はfx.Transformer
を継承します。call_module
メソッドをオーバーライドし、呼び出されたモジュールがnn.ReLU
のインスタンスであるかをisinstance(self.submodules[target], nn.ReLU)
でチェックします。- もし
nn.ReLU
であれば、nn.LeakyReLU()(*args, **kwargs)
を返します。これは、元のnn.ReLU
が受け取っていた入力args
とkwargs
を使って、新しいnn.LeakyReLU
インスタンスを呼び出すことを意味します。FXはこれを新しいノードとしてグラフに記録します。 - それ以外のモジュール呼び出しは、
super().call_module(target, args, kwargs)
を呼び出して、親クラスのデフォルトの動作(元のモジュールを呼び出す)に従います。
例2: 特定のモジュールを削除する (パススルー)
この例では、モデル内のすべてのnn.Dropout
モジュールを効果的に削除します。
import torch
import torch.nn as nn
import torch.fx as fx
# 1. 元のモデルの定義
class MyDropoutModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 10, 3)
self.dropout1 = nn.Dropout(0.2) # これを削除したい
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(10, 20, 3)
self.dropout2 = nn.Dropout(0.5) # これも削除したい
self.pool = nn.MaxPool2d(2)
self.fc = nn.Linear(20 * 5 * 5, 10)
def forward(self, x):
x = self.conv1(x)
x = self.dropout1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.dropout2(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# 2. Transformerの定義: call_moduleをオーバーライド
class DropoutRemover(fx.Transformer):
def call_module(self, target, args, kwargs):
# 呼び出されたモジュールがnn.Dropoutのインスタンスであるかチェック
if isinstance(self.submodules[target], nn.Dropout):
print(f"DEBUG: Removing Dropout at target '{target}'.")
# Dropoutモジュールの場合、その入力をそのまま返す
# Dropoutは通常、単一のテンソル入力を受け取る
assert len(args) == 1, "Dropout is expected to have a single input."
assert not kwargs, "Dropout is not expected to have kwargs."
return args[0] # Dropoutの入力をそのまま出力として返す
else:
# それ以外のモジュールの場合は、デフォルトのTransformerの挙動に従う
return super().call_module(target, args, kwargs)
# 3. モデルのトレースと変換
model = MyDropoutModel()
traced_model = fx.symbolic_trace(model)
print("--- 変換前のグラフ ---")
traced_model.graph.print_tabular()
# Transformerを適用
transformed_model = DropoutRemover(traced_model).transform()
print("\n--- 変換後のグラフ ---")
transformed_model.graph.print_tabular()
# 4. 動作確認 (オプション)
# Dropoutは訓練時のみアクティブになるので、evalモードで比較
model.eval()
transformed_model.eval() # 変換後のモデルはDropoutが削除されているので、通常はevalモードでも同じ
input_data = torch.randn(1, 1, 28, 28)
output_original = model(input_data)
output_transformed = transformed_model(input_data)
# Dropoutが削除されたので、結果は同じになるはず
print(f"\nOriginal Model Output (first 5 values): {output_original[0, :5]}")
print(f"Transformed Model Output (first 5 values): {output_transformed[0, :5]}")
# グラフを見ると、'dropout1' や 'dropout2' のノードがなくなっていることがわかる
解説
MyDropoutModel
を定義し、いくつかのnn.Dropout
モジュールを含めます。DropoutRemover
はfx.Transformer
を継承します。call_module
メソッドをオーバーライドし、呼び出されたモジュールがnn.Dropout
のインスタンスであるかをチェックします。- もし
nn.Dropout
であれば、そのモジュールが受け取った唯一の入力テンソルargs[0]
をそのまま返します。これにより、FXのグラフではDropoutノードが削除され、その入力が直接その次の演算に渡されるようになります。 assert
文で、Dropoutが通常通り1つの位置引数のみを受け取ることを確認しています。
この例では、nn.Linear
モジュールの出力に、ReLUを適用する前に追加のログ出力(ダミー)を挿入します。
import torch
import torch.nn as nn
import torch.fx as fx
# 1. 元のモデルの定義
class MyLoggingModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 2)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
# 2. Transformerの定義: call_moduleをオーバーライド
class LinearOutputLogger(fx.Transformer):
def call_module(self, target, args, kwargs):
# まず、元のモジュールの呼び出しを実行
output = super().call_module(target, args, kwargs)
# 呼び出されたモジュールがnn.Linearのインスタンスであるかチェック
if isinstance(self.submodules[target], nn.Linear):
# ここでカスタムの前処理/後処理ロジックを挿入
# 例: テンソルの統計をログ出力 (ここではダミーの print)
print(f"DEBUG: Linear module '{target}' output shape: {output.shape}")
# 実際には、この情報を外部ファイルに書き込んだり、特別な演算ノードを追加したりできます。
# 例えば、torch.ops.aten.log_softmax を挿入したい場合など。
# log_softmax_output = torch.nn.functional.log_softmax(output, dim=-1)
# return log_softmax_output
return output # 処理された(またはそのままの)出力を返す
# 3. モデルのトレースと変換
model = MyLoggingModel()
traced_model = fx.symbolic_trace(model)
print("--- 変換前のグラフ ---")
traced_model.graph.print_tabular()
# Transformerを適用
# この例ではグラフは変わらないが、実行時にprint文が挿入されることを期待
transformed_model = LinearOutputLogger(traced_model).transform()
print("\n--- 変換後のグラフ (グラフ構造は変わらないはず) ---")
transformed_model.graph.print_tabular()
# 4. 動作確認
input_data = torch.randn(1, 10)
# 変換後のモデルを実行すると、Linearモジュールの出力時にログが(stdoutに)表示される
output_transformed = transformed_model(input_data)
print(f"Final output shape: {output_transformed.shape}")
MyLoggingModel
を定義し、nn.Linear
モジュールを含めます。LinearOutputLogger
はfx.Transformer
を継承します。call_module
メソッドをオーバーライドします。- まず、
super().call_module(target, args, kwargs)
を呼び出して、元のモジュールの呼び出しをFXグラフに記録させ、その結果のテンソルを取得します。 - 次に、呼び出されたモジュールが
nn.Linear
のインスタンスであるかをチェックし、その出力output
に対するカスタムロジック(ここではprint
文)を実行します。この例ではグラフ構造は変更されませんが、実行時の動作に副作用を追加できます。 - 最後に、
output
をそのまま返します。もしoutput
に対して何らかの変換(例:log_softmax
の適用)を行う場合、その変換後のテンソルを返します。
torch.fx.Transformerの他のノードタイプをオーバーライドする
call_module()
はnn.Module
の呼び出しに対応しますが、FXグラフには他にも異なる種類の操作を表すノードが存在します。Transformer
クラスは、これらのノードタイプに対応するオーバーライド可能なメソッドを提供しており、特定の種類の操作に焦点を当てた変換を行うことができます。
主要なノードタイプとそれに対応するメソッド:
output(self, target, args, kwargs)
: モデルのforward
メソッドからの出力に対応します。 例: 出力テンソルに対して後処理を挿入する。placeholder(self, target, args, kwargs)
: モデルのforward
メソッドへの入力(プレースホルダー)に対応します。 例: 入力テンソルに対して前処理を挿入する。get_attr(self, target, args, kwargs)
:self.param
のように、モデルのパラメータやバッファへのアクセスに対応します。target
は属性の完全修飾名(例:'linear.weight'
)です。 例: 特定のパラメータのロード方法を変更する、特定のバッファの取得を最適化するなど。call_method(self, target, args, kwargs)
:Tensor.add()
,Tensor.mean()
,Tensor.view()
のようなテンソルオブジェクトのメソッド呼び出しに対応します。target
はメソッド名(文字列、例:'mean'
)です。 例:x.view()
をカスタムなシェイプ変更ロジックに置き換える、特定のテンソルメソッド呼び出しを削除するなど。call_function(self, target, args, kwargs)
:torch.add()
,F.relu()
,torch.sigmoid()
のようなグローバル関数(torch
名前空間やtorch.nn.functional
内の関数など)の呼び出しに対応します。 例:torch.add
をカスタム関数に置き換える、F.relu
の前に特別な正規化を挿入するなど。
例
torch.add
をカスタムの足し算に置き換える
import torch
import torch.nn as nn
import torch.fx as fx
class MyFunctionModel(nn.Module):
def forward(self, x, y):
# ここでは torch.add() が call_function ノードになる
return torch.add(x, y) + 1
class CustomAddTransformer(fx.Transformer):
def call_function(self, target, args, kwargs):
if target == torch.add:
print("DEBUG: Replacing torch.add with custom logic.")
# 例: 通常の加算の代わりに、要素ごとの乗算を行う
# 実際には、より複雑な演算やカスタムカーネルの呼び出しが考えられる
return args[0] * args[1]
else:
return super().call_function(target, args, kwargs)
model = MyFunctionModel()
traced_model = fx.symbolic_trace(model)
print("--- 変換前のグラフ ---")
traced_model.graph.print_tabular()
transformed_model = CustomAddTransformer(traced_model).transform()
print("\n--- 変換後のグラフ ---")
transformed_model.graph.print_tabular()
# 動作確認
input1 = torch.tensor([1, 2, 3])
input2 = torch.tensor([4, 5, 6])
output_original = model(input1, input2) # (1+4)+1=6, (2+5)+1=8, (3+6)+1=10 -> [6, 8, 10]
output_transformed = transformed_model(input1, input2) # (1*4)+1=5, (2*5)+1=11, (3*6)+1=19 -> [5, 11, 19]
print(f"\nOriginal Output: {output_original}")
print(f"Transformed Output: {output_transformed}")
torch.fx.Interpreterを使用する
Interpreter
クラスは、FXグラフを「実行」するメカニズムを提供します。Transformer
がグラフを変換して新しいGraphModule
を生成するのに対し、Interpreter
はグラフの各ノードが実行される際にカスタムロジックを挿入することを可能にします。これは、グラフを変換するのではなく、グラフの実行を監視したり、プロファイリングしたり、特定のノードのデバッグ情報を出力したりする場合に特に有用です。
Interpreter
もcall_module()
、call_function()
などのメソッドをオーバーライドできますが、それらは変換ではなく実行時の挙動を制御します。
import torch
import torch.nn as nn
import torch.fx as fx
class MyProfilingModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x)
x = self.relu(x)
return x
class ProfilingInterpreter(fx.Interpreter):
def run_node(self, node: fx.Node):
# すべてのノードの実行前に呼び出される
print(f"DEBUG: Executing node: {node.op} - {node.target}")
# ここで、ノードのタイプに応じて異なる処理を行う
if node.op == 'call_module':
module_name = node.target
module_instance = self.module.get_submodule(module_name)
print(f" -> Calling module: {module_name} ({type(module_instance).__name__})")
elif node.op == 'call_function':
print(f" -> Calling function: {node.target.__name__}")
elif node.op == 'call_method':
print(f" -> Calling method: {node.target}")
# 元のノードの実行ロジックを呼び出す
# これが実際の計算を実行し、結果のテンソルを返す
result = super().run_node(node)
# 実行後の処理
if isinstance(result, torch.Tensor):
print(f" -> Output shape: {result.shape}")
return result
model = MyProfilingModel()
traced_model = fx.symbolic_trace(model)
input_data = torch.randn(1, 10)
print("--- Interpreter を使った実行 ---")
interpreter = ProfilingInterpreter(traced_model)
output = interpreter.run(input_data) # run() メソッドがグラフの実行を開始
print(f"Final output: {output.shape}")
解説
ProfilingInterpreter
は、各ノードが実行される際にrun_node()
メソッドが呼び出されます。この中で、node.op
やnode.target
を使ってノードの種類を判別し、カスタムのログ出力などの処理を行うことができます。super().run_node(node)
を呼び出すことで、FXのデフォルトのノード実行ロジックが実行され、適切なテンソルが返されます。
replace_pattern
は、より高レベルのAPIで、特定のグラフパターンを別のグラフパターンに置き換えるために使用されます。これは、例えば、Conv
とBatchNorm
の融合のような、特定のモジュールや関数の組み合わせをより効率的な単一の演算に置き換えたい場合に非常に強力です。
replace_pattern
は内部的にcall_module
やcall_function
などのノードを識別し、指定されたパターンに合致する部分を置き換えます。手動でcall_module
をオーバーライドして複雑なパターンマッチングを行うよりも、コードが簡潔になることが多いです。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx
from torch.fx.passes.utils.fuser_utils import fuse_conv_bn_eval
# 1. 元のモデルの定義 (Conv + BatchNorm のパターンを含む)
class MyConvBnModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
self.linear = nn.Linear(16 * 10 * 10, 10) # 例のため適当なサイズ
def forward(self, x):
x = self.conv(x)
x = self.bn(x) # Conv + BatchNorm のパターン
x = self.relu(x)
# Flatten and then apply linear layer for demonstration
x = x.view(x.size(0), -1)
x = self.linear(x)
return x
# 2. パターンマッチングと置き換えを行う関数 (通常は eval モードで行われる最適化)
# これは PyTorch の内部パスとして提供されているものを使うのが一般的
# fuse_conv_bn_eval は nn.Conv2d と nn.BatchNorm2d のシーケンスを単一の nn.Conv2d に変換する
# (ここでは簡略化のため、自作のフュージョン関数は省略し、説明に留めます)
# replace_pattern のデモのために、より簡単な置き換えパターンを作成
# 例えば、`linear(x) + y` を `linear(x * y)` に置き換える
class Pattern(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5) # ダミー、実際にはマッチするモジュール
def forward(self, x, y):
# この forward がグラフとしてトレースされ、パターンとなる
return self.linear(x) + y
class Replacement(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 5) # ダミー、実際にはマッチするモジュール
def forward(self, x, y):
# この forward がグラフとしてトレースされ、置き換え後のパターンとなる
return self.linear(x * y)
# 実際のモデル
class MyPatternModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.linear2 = nn.Linear(5, 2)
def forward(self, x, y):
# linear1(x) + y のパターンをマッチさせたい
return self.linear1(x) + y
# 3. モデルのトレースと replace_pattern の適用
model = MyPatternModel()
traced_model = fx.symbolic_trace(model)
print("--- 変換前のグラフ ---")
traced_model.graph.print_tabular()
# パターンと置き換え用のダミーモジュールをトレース
pattern_traced = fx.symbolic_trace(Pattern())
replacement_traced = fx.symbolic_trace(Replacement())
# パターンマッチングと置き換えを実行
# ここでは、MyPatternModel の linear1 に対応する Pattern の linear を明示的にマップ
# (実際には、名前ベースのマッチングや型ベースのマッチングが行われる)
# NOTE: replace_pattern は、pattern と replacement の GraphModule が持つモジュール名を
# マッチさせる必要がある。ここでは、pattern.linear と replacement.linear を
# model.linear1 にマッピングする。
new_graph = fx.replace_pattern(traced_model, pattern_traced, replacement_traced)
# 新しい GraphModule を作成 (GraphModule は nn.Module のサブクラス)
transformed_model = fx.GraphModule(traced_model, new_graph)
print("\n--- 変換後のグラフ ---")
transformed_model.graph.print_tabular()
# 動作確認 (入出力の変更に注意)
input_x = torch.randn(1, 10)
input_y = torch.randn(1, 10)
# transformed_model の入力は (x, y) なので、Pattern モデルの入力に合わせて呼び出す
# Modelのforward定義を変更していないので、型エラーになるが、概念的な説明
# output_original = model(input_x, input_y)
# output_transformed = transformed_model(input_x, input_y)
解説
replace_pattern
は、call_module
レベルでの手動のノード操作よりも、高レベルで安全な方法でグラフを変換します。特に、複数の操作からなる複合的なパターンを置き換えたい場合に非常に有効です。call_module
や他のノードのオーバーライドは、より低レベルで粒度の細かい制御が必要な場合に適しています。
torch.fx.Transformer.call_module()
は、nn.Module
の呼び出しを対象とした強力な変換フックですが、FXには他にも多くの代替手段があります。
replace_pattern
: 複数のノードからなる複雑なパターンを置き換えたい場合に、高レベルで宣言的な方法を提供します。Interpreter
: グラフの変換ではなく、グラフの実行時の振る舞いを監視したり、デバッグ情報を収集したりするのに適しています。Transformer
の他のメソッド (call_function
,call_method
,get_attr
など): 特定の種類のノード操作に特化したい場合に、より直接的な制御を提供します。