PyTorchのFXを使いこなす!torch.fx.wrap()の基本と活用法

2025-05-31

FXは、PyTorchモデルのPythonコードをシンボリックトレースすることで、計算グラフ(Graph IR)を生成し、そのグラフを最適化したり変換したりするための強力なツールです。通常、FXはモデルのforwardメソッド内のすべてのPyTorch操作を詳細に記録しようとします。

しかし、以下のようなケースでは、FXによる詳細なトレースが望ましくない、または不可能になることがあります。

  • 特定の関数の動作を「ブラックボックス」として扱いたい場合
    最適化の対象外としたい、あるいは内部の詳細を隠蔽したい関数がある場合。
  • カスタムのC++/CUDAカーネル
    PyTorchの標準的な操作として定義されていない、カスタムのC++やCUDAで書かれたカーネルを呼び出す場合、FXはそれらを認識できません。
  • 外部ライブラリの呼び出し
    PyTorchのテンソル操作ではない、NumPyのような外部ライブラリの関数や、Pythonの標準ライブラリの関数(例: len())を呼び出す場合、FXはその内部の操作をトレースできません。
  • 動的な制御フロー (Dynamic Control Flow)
    if文やforループなど、入力データによって処理が変わるようなロジックが含まれる場合、FXはこれらの動的な挙動を静的なグラフとして表現するのが難しいことがあります。

このような場合にtorch.fx.wrap()を使用すると、指定された関数やモジュールがFXのシンボリックトレースの対象から除外され、代わりに「単一のノード」としてグラフに記録されます。これにより、FXはグラフの連続性を保ちつつ、トレースできない部分やトレースしたくない部分を「葉(leaf)ノード」として扱うことができます。

torch.fx.wrap() の使い方

torch.fx.wrap()はデコレータとして、あるいは関数として使用できます。

  1. デコレータとして関数に適用する例

    import torch
    import torch.fx as fx
    import math
    
    @fx.wrap
    def my_custom_operation(x, y):
        # この関数内の処理はFXによって詳細にはトレースされず、
        # `my_custom_operation`という単一のノードとしてグラフに記録されます。
        return torch.sqrt(x) + math.log(y)
    
    class MyModule(torch.nn.Module):
        def forward(self, x, y):
            a = x * 2
            b = y + 1
            c = my_custom_operation(a, b)
            return c
    
    model = MyModule()
    traced_model = fx.symbolic_trace(model)
    
    print(traced_model.graph)
    # 出力例:
    # graph():
    #     %x : [num_users=1] = placeholder[target=x]
    #     %y : [num_users=1] = placeholder[target=y]
    #     %a : [num_users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %b : [num_users=1] = call_function[target=operator.add](args = (%y, 1), kwargs = {})
    #     %c : [num_users=1] = call_function[target=my_custom_operation](args = (%a, %b), kwargs = {}) # ここが単一ノードになる
    #     return %c
    

    上記の例では、my_custom_operation関数が@fx.wrapでデコレートされているため、その内部のtorch.sqrtmath.logといった個々の操作はFXのグラフには現れません。代わりに、my_custom_operationというノードがグラフに直接記録されます。

  2. import torch
    import torch.fx as fx
    
    # 例えば、Pythonの標準関数をラップしたい場合
    # fx.wrap(len) をどこか適切な場所(通常はファイルのトップレベル)に記述します。
    # これにより、FXがlen()をシンボリックトレースしようとしたときにエラーにならないようにします。
    fx.wrap('len')
    
    class AnotherModule(torch.nn.Module):
        def forward(self, x):
            length = len(x.shape) # len() がグラフに単一ノードとして記録される
            return x + length
    
    model = AnotherModule()
    traced_model = fx.symbolic_trace(model)
    print(traced_model.graph)
    
  • グローバル関数または文字列名
    torch.fx.wrap()は、グローバル関数か、関数の文字列名を引数に取る必要があります。メソッドに直接デコレータとして適用することはできません(ただし、FXの新しい機能では一部の状況で可能になっている場合もあります)。
  • 引数と戻り値の型
    wrapされた関数への入力と出力は、FXが扱える型(テンソル、数値、タプル、リストなど)である必要があります。特に、torch.fx.Proxyオブジェクトを処理できる必要があります。
  • 内部はトレースされない
    wrapされた関数やモジュール内のPyTorch操作は、FXのグラフには詳細に現れません。これは、FXがその部分を最適化したり変更したりできないことを意味します。


torch.fx.wrap() に関連するよくあるエラーとトラブルシューティング

RuntimeError: '...' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('...') at module scope.

これは最も一般的なエラーメッセージの一つです。FXがトレースしようとした関数が、デフォルトではトレースできない(またはトレースしたくない)Pythonの組み込み関数や、PyTorchではない外部ライブラリの関数である場合に発生します。

原因

  • torch.fx.wrap() の呼び出しが適切に行われていない、または呼び出すべき場所ではない。
  • NumPyなどのPyTorchテンソルではないオブジェクトを扱う外部ライブラリの関数を呼び出している。
  • len(), isinstance(), print(), assert などのPython組み込み関数を使用している。

トラブルシューティング

  • torch.compile との組み合わせ
    torch.compile を使用している場合、内部でFXが使われているため、同様のエラーに遭遇することがあります。解決策は同じく torch.fx.wrap() を適切に呼び出すことです。

  • デコレータとして使用する場合
    ユーザー定義関数をラップする場合は、デコレータとして@fx.wrapを使用するのが最も一般的で推奨される方法です。

    import torch
    import torch.fx as fx
    
    @fx.wrap # ここでデコレータとして適用
    def custom_op(x):
        # この内部はトレースされず、単一ノードとして記録される
        return x.sum() / x.numel()
    
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return custom_op(x)
    
    model = MyModule()
    traced_model = fx.symbolic_trace(model)
    print(traced_model.graph)
    
  • torch.fx.wrap() をモジュールスコープで呼び出す
    エラーメッセージが示唆するように、対象の関数をFXにラップさせるには、torch.fx.wrap() をその関数が定義されているモジュールのトップレベル(グローバルスコープ)で呼び出す必要があります。

    import torch
    import torch.fx as fx
    import math
    
    # 正しい例: モジュールスコープでwrapを呼び出す
    fx.wrap('len')
    fx.wrap('math.log') # math.logのように、モジュール内の関数を指定することも可能
    
    def my_function(x):
        a = len(x) # lenはこれでトレースされる
        b = math.log(a) # math.logもこれでトレースされる
        return x * b
    
    model = my_function # 関数を直接トレースすることも可能
    traced_model = fx.symbolic_trace(model)
    print(traced_model.graph)
    

torch.fx.wrap() が機能しない、または期待通りのグラフが生成されない

wrap を適用したにもかかわらず、関数が詳細にトレースされてしまったり、逆に全く認識されないノードになってしまったりするケースです。

原因

  • 異なるモジュールからのインポート
    ラップしたい関数が別のファイルやモジュールで定義されている場合、fx.wrap('function_name') のように文字列で指定するだけでは不十分な場合があります。その関数が実際にインポートされ、Pythonのグローバル名前空間で利用可能になっている必要があります。
  • メソッドのラッピング
    torch.fx.wrap は通常、グローバル関数を対象とします。クラスのインスタンスメソッドを直接 wrap するのは難しい場合があります(特に古いPyTorchバージョンでは)。
  • 動的な制御フロー
    wrap された関数内に、FXが静的なグラフとして表現できない動的な制御フロー(データ依存の if 文や for ループ)が含まれている場合。wrap はその関数を単一ノードとして扱いますが、内部の動作を完全に無視するわけではありません。wrap は「その関数の呼び出し自体をノードにする」だけであり、その内部のロジックがトレースに適さない場合は、それでも問題が発生する可能性があります。

トラブルシューティング

  • インポートパスの確認
    fx.wrap('my_module.my_function') のように、完全な修飾名(fully qualified name)で指定するか、関数を現在のモジュールに直接インポートしてから wrap します。

    # my_utils.py
    import torch
    def helper_function(x):
        return x.sum()
    
    # main.py
    import torch
    import torch.fx as fx
    from my_utils import helper_function
    
    fx.wrap('my_utils.helper_function') # これでOK
    
    class MyModel(torch.nn.Module):
        def forward(self, x):
            return helper_function(x)
    
    model = MyModel()
    traced_model = fx.symbolic_trace(model)
    print(traced_model.graph)
    
  • メソッドのラッピングの代替手段
    クラスメソッドをラップしたい場合は、そのメソッドをクラスの外にヘルパー関数として定義し、それを wrap するか、モジュール全体をカスタムのFX Tracer で処理することを検討します。

  • FXの基本的な制限を理解する
    torch.fx.wrap は、動的な制御フローの問題を根本的に解決するものではありません。あくまで「この関数はブラックボックスとして扱ってね」という指示です。動的な挙動が問題なら、モデルの設計を見直すか、FX以外の方法を検討する必要があります。

AttributeError: 'Proxy' object has no attribute '...'

wrap された関数に渡された引数や、その関数内で操作されるオブジェクトが、FXの Proxy オブジェクトであり、通常のPythonオブジェクトのように振る舞わない場合に発生します。

原因

  • Proxy オブジェクトの属性に直接アクセスしようとしているが、それがFXのグラフには存在しない場合。
  • wrap された関数が、Proxy オブジェクトに対してPythonの組み込み関数や、テンソル以外の操作(例: list.append(), dict.keys(), len() をテンソルではないリストに適用するなど)を実行しようとしているが、これらの操作がFXによって適切に処理できない。

トラブルシューティング

  • len() のような操作
    len(tensor) は通常FXでトレース可能ですが、len(list_of_tensors) のような場合、list_of_tensorsProxy オブジェクトのリストであるため、問題になることがあります。この場合、len を別途ラップする必要があります。
  • wrap された関数の入力/出力をテンソルに限定する
    wrap された関数がPyTorchテンソルのみを引数として受け取り、PyTorchテンソルのみを返すように設計することを強く推奨します。もし非テンソルデータが必要であれば、FXグラフの外でそれらを処理するか、トレース後に別途処理するロジックを検討します。
  • Proxy オブジェクトの制限を理解する
    FXはPyTorchテンソルとそれに関連する操作をトレースすることに特化しています。Proxy オブジェクトは実際の値ではなく、グラフのノードを表現するものです。wrap された関数内で、Proxy に直接対応しないPythonの標準的なコンテナ操作や、非テンソルデータ型へのアクセスは避けるべきです。

torch.fx.wrap を使用するとパフォーマンスが低下する

torch.fx.wrap() は最適化の機会を失う可能性があります。

原因

  • wrap された関数はFXのグラフで単一のノードとして扱われるため、その内部の操作はFXによる最適化(オペレーターフュージョンなど)の対象外となります。

トラブルシューティング

  • wrap の範囲を最小限にする
    最適化を阻害しないよう、wrap を適用する範囲は必要最小限に留めるべきです。
  • 本当に wrap が必要か再検討する
    その関数が本当にトレースできない動的なロジックを含んでいるのか、あるいはPyTorchテンソル操作のみで構成されているが、単にFXが認識していないだけなのかを検討します。後者の場合、カスタムの Tracer を実装するか、PyTorchのコア開発チームに機能リクエストを出すことを検討します。
  • 公式ドキュメントとフォーラムを参照する
    PyTorchの公式ドキュメント(特にFXのセクション)は非常に詳細です。また、PyTorchフォーラムやGitHub Issuesで同様の問題が報告されていないか検索するのも有効です。
  • PyTorchのバージョンを確認する
    FXは比較的新しい機能であり、PyTorchのバージョンによって挙動やサポートされる機能が異なる場合があります。最新の安定版を使用しているか確認し、必要であればアップデートを検討してください。
  • 最小限の再現コードを作成する
    問題を切り分けるために、問題が発生する最小限のコードスニペットを作成してみてください。これにより、複雑なモデルの中から問題の箇所を特定しやすくなります。
  • エラーメッセージをよく読む
    PyTorchのエラーメッセージは非常に具体的で役立つ情報を含んでいることが多いです。特に、RuntimeError のメッセージは、何がサポートされていないのか、どのように解決すればよいのかを直接示唆していることがあります。


torch.fx.wrap()は、PyTorchのFX (Functional eXchange) によるモデルのシンボリックトレースにおいて、特定の関数を「ブラックボックス」として扱い、その内部を詳細にトレースしないようにするための機能です。

基本的な考え方

FXは通常、モデルのforwardメソッド内のすべてのPyTorch演算を詳細に記録し、グラフとして表現します。しかし、以下のような場合には、FXがそのままトレースできない、または詳細なトレースが望ましくないことがあります。

  • 最適化の対象外としたい、または内部実装を隠蔽したいカスタム関数
  • データ依存の動的な制御フロー(非常に複雑なif文やforループ)
  • 外部ライブラリの関数(例: NumPyの関数)
  • Pythonの組み込み関数(例: len(), isinstance()

このような場合にtorch.fx.wrap()を使うと、指定された関数がFXグラフ内で単一のノードとして表現され、その内部の詳細は無視されます。

プログラミング例

例1: Pythonの組み込み関数をラップする

len()関数は、そのままではFXのトレース中にエラーになることがあります。これをtorch.fx.wrap()でラップすることで、グラフに単一のノードとして含めることができます。

import torch
import torch.fx as fx

# torch.fx.wrap()はモジュールスコープ(グローバルスコープ)で呼び出す必要があります。
# これにより、FXは len() が呼び出されたときにこれを単一のノードとして扱います。
fx.wrap('len') # 文字列で関数名を指定

class MyModuleWithLen(torch.nn.Module):
    def forward(self, x):
        # x.shape はテンソルではないタプルを返すため、len() はPython組み込みの len() が呼ばれます。
        shape_len = len(x.shape)
        return x * shape_len

# モデルのインスタンス化
model = MyModuleWithLen()

# シンボリックトレース
traced_model = fx.symbolic_trace(model)

# 生成されたグラフを表示
print("--- 例1: len() をラップした場合のグラフ ---")
print(traced_model.graph)

"""
出力例:
--- 例1: len() をラップした場合のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %shape : [num_users=1] = call_method[target=shape](args = (%x,), kwargs = {})
    %shape_len : [num_users=1] = call_function[target=len](args = (%shape,), kwargs = {}) # len() が単一ノードとして記録される
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %shape_len), kwargs = {})
    return %mul
"""

# トレースされたモデルを実行してみる
dummy_input = torch.randn(2, 3, 4)
output_original = model(dummy_input)
output_traced = traced_model(dummy_input)
print(f"オリジナルモデルの出力: {output_original.shape}")
print(f"トレースされたモデルの出力: {output_traced.shape}")
assert torch.allclose(output_original, output_traced)
print("出力が一致しました。")

例2: ユーザー定義関数をデコレータとしてラップする

カスタムのヘルパー関数があり、その内部の詳細をFXにトレースさせたくない場合に、デコレータ@fx.wrapを使用します。

import torch
import torch.fx as fx
import math # mathモジュールもPython標準なので、直接トレースできません

# @fx.wrap デコレータを使ってカスタム関数をラップします
@fx.wrap
def custom_complex_op(tensor_a, tensor_b):
    # この関数内の操作(torch.sqrt, math.logなど)は、FXグラフには個々のノードとしては現れません。
    # 代わりに、custom_complex_op という単一のノードとして記録されます。
    intermediate = torch.sqrt(tensor_a) * 2.0
    result = intermediate + math.log(tensor_b.sum().item()) # .item()でPythonのfloatに変換し、math.logを呼び出す
    return result

class MyModuleWithCustomOp(torch.nn.Module):
    def forward(self, x, y):
        # ラップされたカスタム関数を呼び出します
        output = custom_complex_op(x, y)
        return output * 3.0

model = MyModuleWithCustomOp()
traced_model = fx.symbolic_trace(model)

print("\n--- 例2: ユーザー定義関数をデコレータでラップした場合のグラフ ---")
print(traced_model.graph)

"""
出力例:
--- 例2: ユーザー定義関数をデコレータでラップした場合のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %output : [num_users=1] = call_function[target=custom_complex_op](args = (%x, %y), kwargs = {}) # ここが単一ノードになる
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%output, 3.0), kwargs = {})
    return %mul
"""

# トレースされたモデルを実行してみる
dummy_x = torch.randn(5)
dummy_y = torch.abs(torch.randn(5)) + 1e-6 # logのために正の値にする
output_original = model(dummy_x, dummy_y)
output_traced = traced_model(dummy_x, dummy_y)
print(f"オリジナルモデルの出力: {output_original.shape}")
print(f"トレースされたモデルの出力: {output_traced.shape}")
assert torch.allclose(output_original, output_traced)
print("出力が一致しました。")

例3: 外部ライブラリの関数をラップする(NumPyの例)

通常、NumPyの操作はPyTorchのテンソル操作とは異なるため、FXは直接トレースできません。しかし、torch.fx.wrap()を使うことで、NumPyの関数呼び出しをグラフに含めることができます。

注意点
torch.fx.wrap()を使って外部ライブラリの関数をラップする場合、その関数がPyTorchテンソルとPythonの組み込み型(数値、タプルなど)を引数として受け取り、同様の型を返す必要があります。NumPyの関数がPyTorchテンソルを直接受け取ると、自動的にNumPy配列に変換されることがありますが、その挙動は注意深く確認する必要があります。

import torch
import torch.fx as fx
import numpy as np

# NumPyの関数をラップします
# NumPy関数は通常Pythonオブジェクト(例えばPython floatやint)を扱うため、
# PyTorchテンソルを扱うFXのProxyオブジェクトと相性が悪いことがあります。
# ここでは、テンソルをPythonの数値に変換してからNumPy関数に渡す例を示します。
fx.wrap('numpy.mean') # 文字列で指定

class MyModuleWithNumpy(torch.nn.Module):
    def forward(self, x):
        # x.mean().item() でテンソルの平均値をPythonのfloatに変換
        # その後、numpy.mean() を呼び出すことで、トレース可能にする
        mean_val_np = np.mean(x.mean().item()) # numpy.mean() が単一ノードになる
        return x * mean_val_np

model = MyModuleWithNumpy()
traced_model = fx.symbolic_trace(model)

print("\n--- 例3: NumPy関数をラップした場合のグラフ ---")
print(traced_model.graph)

"""
出力例:
--- 例3: NumPy関数をラップした場合のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mean : [num_users=1] = call_method[target=mean](args = (%x,), kwargs = {})
    %item : [num_users=1] = call_method[target=item](args = (%mean,), kwargs = {})
    %mean_val_np : [num_users=1] = call_function[target=numpy.mean](args = (%item,), kwargs = {}) # numpy.meanが単一ノードに
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %mean_val_np), kwargs = {})
    return %mul
"""

# トレースされたモデルを実行してみる
dummy_input = torch.randn(3, 3)
output_original = model(dummy_input)
output_traced = traced_model(dummy_input)
print(f"オリジナルモデルの出力: {output_original.shape}")
print(f"トレースされたモデルの出力: {output_traced.shape}")
assert torch.allclose(output_original, output_traced)
print("出力が一致しました。")
  • 最適化の機会の喪失: ラップされた関数はFXによる詳細な最適化の対象外となります。したがって、パフォーマンスが重要なコアな計算ロジックには、できるだけwrapを使用しない方が良いでしょう。wrapは、どうしてもトレースできない部分や、最適化が不要な(あるいは難しい)部分に限定して使用することが推奨されます。
  • 引数と戻り値: ラップされた関数は、FXが扱える型の引数(torch.Tensor、数値、タプルなど)を受け取り、同様の型を返す必要があります。特に、PyTorchのテンソルを返すことが期待されます。そうでない場合、後続の操作でエラーが発生する可能性があります。
  • デコレータ: ユーザー定義関数に対しては、@fx.wrap デコレータを使用するのが最も一般的で分かりやすい方法です。これも関数定義の直前(通常はモジュールスコープ)に置かれます。
  • グローバルスコープでの呼び出し: fx.wrap('function_name') の形式で関数名を文字列で指定する場合、その呼び出しはPythonファイルのトップレベル(モジュールスコープ)で行う必要があります。これは、FXがトレースを開始する前に、どの関数をラップするかを把握しておく必要があるためです。


カスタム torch.fx.Tracer の実装

torch.fx.symbolic_trace() の内部では、torch.fx.Tracer クラスが使用されています。この Tracer クラスをサブクラス化することで、トレースの挙動をより細かく制御できます。

目的

  • 特定の操作を異なる方法で表現する。
  • 特定の型のオブジェクトがグラフにどのように記録されるかをカスタマイズする。
  • 特定のモジュールや関数を、wrap() を使わずにリーフノード(トレースの終端)として扱う。

利点

  • 特定の種類のモジュール(例: サードパーティライブラリの特定の層)を自動的にリーフとして扱うルールを設定できる。
  • wrap() よりも柔軟性が高く、トレースのロジック全体をカスタマイズできる。

欠点

  • 汎用的な解決策ではなく、特定のニーズに合わせてカスタマイズが必要。
  • 複雑さが増す。FXの内部動作に関する深い理解が必要。


ある種のモジュール(例: MyNonTraceableModule)を、その内部をトレースせずに常にリーフノードとして扱いたい場合。

import torch
import torch.nn as nn
import torch.fx as fx

# トレースしたくない、カスタムの(またはサードパーティの)モジュール
class MyNonTraceableModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(10))

    def forward(self, x):
        # ここで複雑な非PyTorch操作や動的なロジックがあると仮定
        result = x + self.param.mean()
        # 例えば、外部のC++ライブラリを呼び出すようなイメージ
        import math
        return result * math.sin(x.sum().item())

# カスタムTracerの定義
class CustomTracer(fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, qualname: str) -> bool:
        # MyNonTraceableModule のインスタンスをリーフとして扱う
        if isinstance(m, MyNonTraceableModule):
            return True
        # デフォルトの挙動(nn.ModuleList, nn.Sequentialなどはトレースし、
        # nn.Linear, nn.Conv2dなどはリーフとして扱わない)に従う
        return super().is_leaf_module(m, qualname)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 10)
        self.custom_module = MyNonTraceableModule()
        self.linear2 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.custom_module(x) # このモジュールがリーフとして扱われる
        x = self.linear2(x)
        return x

model = MyModel()
# カスタムTracerを使用してモデルをトレース
tracer = CustomTracer()
traced_model = tracer.trace(model)
graph_module = fx.GraphModule(model, traced_model)

print("--- 1. カスタム Tracer を使用した場合のグラフ ---")
print(graph_module.graph)

"""
出力例の一部:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear1 : [num_users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
    %custom_module : [num_users=1] = call_module[target=custom_module](args = (%linear1,), kwargs = {}) # MyNonTraceableModule が単一ノードに
    %linear2 : [num_users=1] = call_module[target=linear2](args = (%custom_module,), kwargs = {})
    return %linear2
"""

サブグラフの置き換え (torch.fx.subgraph_rewriter.replace_pattern)

特定のパターン(オペレーションのシーケンス)をFXグラフ内で見つけ、それを別のオペレーションのシーケンス(または単一の関数呼び出し)に置き換えることができます。これはwrap()とは少し目的が異なりますが、グラフを簡素化したり、特定の非トレース可能パターンを事前に処理したりするのに役立ちます。

目的

  • wrap()だけでは解決できないような、特定の構造を持つ非トレース可能コードをグラフで表現できる形に変換する。
  • 複雑な複数の操作を、よりシンプルな単一のノードに集約する(ただし、その単一ノードはFXが認識できるものでなければならない)。
  • モデル内の特定の計算パターンを、最適化されたカスタム実装に置き換える。

利点

  • 既知の非効率なパターンを最適化されたパターンに自動的に置き換えることができる。
  • グラフのセマンティクスを維持しつつ、柔軟な変換が可能。

欠点

  • 動的な制御フローを持つ非トレース可能なコードを、この方法だけで完全に解決するのは難しい場合がある。
  • wrap()のように「どんなものでもブラックボックスにする」わけではない。置き換え元のパターンと置き換え先のパターンが明確に定義されている必要がある。


x + y + z というパターンを torch.add(torch.add(x, y), z) から、より効率的な(または異なる)x + (y + z) に置き換える、あるいはカスタムの統合されたAdd関数に置き換える。

import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx.subgraph_rewriter import replace_pattern

class MyAddModule(nn.Module):
    def forward(self, x, y, z):
        a = x + y
        b = a + z # このパターンを置き換えたい
        return b

# 置き換えたいパターンを定義する関数
def pattern_to_match(x, y, z):
    a = x + y
    b = a + z
    return b

# 置き換え後のパターンを定義する関数
def replacement_pattern(x, y, z):
    # 例えば、3つのテンソルを一度に追加するカスタム関数があるとする
    # あるいは、異なる結合順序にする
    return x + (y + z) # or custom_add3_op(x, y, z)

model = MyAddModule()
traced_model = fx.symbolic_trace(model)

print("--- 2. サブグラフ置き換え前のグラフ ---")
print(traced_model.graph)

# パターンを置き換え
replace_pattern(traced_model, pattern_to_match, replacement_pattern)

print("\n--- 2. サブグラフ置き換え後のグラフ ---")
print(traced_model.graph)

"""
出力例の一部:
--- 2. サブグラフ置き換え前のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%add, %z), kwargs = {})
    return %add_1

--- 2. サブグラフ置き換え後のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %add : [num_users=1] = call_function[target=operator.add](args = (%y, %z), kwargs = {}) # 結合順序が変わった
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %add), kwargs = {})
    return %add_1
"""

dummy_x, dummy_y, dummy_z = torch.randn(2), torch.randn(2), torch.randn(2)
output_original = model(dummy_x, dummy_y, dummy_z)
output_traced_rewritten = traced_model(dummy_x, dummy_y, dummy_z)
assert torch.allclose(output_original, output_traced_rewritten)
print("出力が一致しました。")

torch.autograd.Function を使用したカスタムオペレーションの定義

もし、C++やCUDAで書かれた独自のバックエンド操作があり、それをPyTorchのグラフに統合したい場合は、torch.autograd.Function を使用してカスタムオペレーションを定義するのが最も強力な方法です。

目的

  • C++/CUDAカーネルなど、Pythonでは実装できない高性能な計算ロジックを組み込む。
  • PyTorchのテンソル計算グラフに完全に統合された、カスタムのフォワードパスとバックワードパスを持つ操作を作成する。

利点

  • 最高のパフォーマンスと柔軟性を提供。
  • wrap() とは異なり、カスタム操作の内部をFXグラフに詳細に記録しないものの、その操作自体は完全にFX/Autogradシステムに認識されるため、最適化や勾配計算の恩恵を受けられる。

欠点

  • 純粋なPython関数を「トレースしたくない」という目的には大げさすぎる。
  • 実装の複雑さが最も高い。C++/CUDAの知識が必要となる場合が多い。


(これは概念的なもので、実際のC++カーネルの実装は省略)

import torch

class CustomExpMinusOne(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ここでカスタムのC++/CUDAカーネルを呼び出すと仮定
        # 簡単な例として、Pythonで実装
        ctx.save_for_backward(x)
        return torch.exp(x) - 1

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # カスタムのC++/CUDAカーネルを呼び出すと仮定
        # 簡単な例として、Pythonで実装
        grad_x = grad_output * torch.exp(x)
        return grad_x

class MyModuleWithCustomAutograd(torch.nn.Module):
    def forward(self, x):
        return CustomExpMinusOne.apply(x)

model = MyModuleWithCustomAutograd()
traced_model = fx.symbolic_trace(model)

print("\n--- 3. torch.autograd.Function を使用した場合のグラフ ---")
print(traced_model.graph)

"""
出力例:
--- 3. torch.autograd.Function を使用した場合のグラフ ---
graph():
    %x : [num_users=1] = placeholder[target=x]
    %apply : [num_users=1] = call_function[target=CustomExpMinusOne.apply](args = (%x,), kwargs = {}) # apply が単一ノードとして記録される
    return %apply
"""

dummy_input = torch.randn(5)
output_original = model(dummy_input)
output_traced = traced_model(dummy_input)
print(f"オリジナルモデルの出力: {output_original.shape}")
print(f"トレースされたモデルの出力: {output_traced.shape}")
assert torch.allclose(output_original, output_traced)
print("出力が一致しました。")

torch.compile はモデル全体のコンパイルを試みますが、一部の動的なPythonコードやサポートされていない操作に遭遇した場合、自動的に「ブレーク」(部分的にコンパイルを諦める)して、残りの部分をPythonで実行します。これはwrap()の目的と似ていますが、より高レベルで自動的に行われます。

目的

  • モデル全体を自動的に最適化したいが、一部の非トレース可能コードを強制的にPythonフォールバックさせたい場合。

利点

  • コンパイラが最適な分割点を自動で探す。
  • ユーザーが明示的にwrap()を記述する必要がないことが多い。

欠点

  • fullgraph=Trueを指定しない限り、ブレークによってコンパイルの恩恵が限定される可能性がある。
  • wrap()のように特定の関数を明示的にマークするわけではないので、どこでブレークが発生するかを正確に制御できない場合がある。


import torch
import torch.nn as nn
import math

class MyDynamicModule(nn.Module):
    def forward(self, x):
        # データに依存する動的な制御フロー(例として)
        if x.mean().item() > 0:
            y = x * 2
        else:
            y = x / 2

        # 外部ライブラリの呼び出し
        # torch.compile はこれをフォールバックさせる可能性が高い
        z = y + math.sqrt(y.sum().item())
        return z

model = MyDynamicModule()

# torch.compile を使用(デフォルトでは fullgraph=False の挙動)
# `torch._dynamo.explain` を使って、どこでブレークが発生するかを確認できる
# from torch._dynamo.explain import explain
# explain(model, torch.randn(5))

compiled_model = torch.compile(model)

print("\n--- 4. torch.compile を使用した場合 ---")
dummy_input = torch.randn(5)
output_original = model(dummy_input)
output_compiled = compiled_model(dummy_input)
print(f"オリジナルモデルの出力: {output_original.shape}")
print(f"コンパイルされたモデルの出力: {output_compiled.shape}")
assert torch.allclose(output_original, output_compiled)
print("出力が一致しました。")
# 実際にグラフが表示されるわけではないが、内部で一部がフォールバックされる

torch.fx.wrap()は、FXグラフにおける特定の関数呼び出しを「単一のノード」として扱うための最も直接的な方法です。しかし、状況によっては、以下のような代替手段がより適切かもしれません。

  • torch.compile: より高レベルで自動的に非トレース可能な部分を処理し、コンパイルの恩恵を受けたい場合。
  • torch.autograd.Function: C++/CUDAなどで実装されたカスタム操作をPyTorchのグラフに完全に統合し、勾配計算も行いたい場合。
  • サブグラフの置き換え: グラフ内の特定のパターンを別のパターンに変換・最適化したい場合。
  • カスタム Tracer: より細かくトレースの挙動を制御したい場合や、特定のモジュールタイプ全体をリーフとして扱いたい場合。