torch.fx.Tracer.create_proxy()

2025-05-31

  1. プロキシオブジェクトの生成: torch.fx.Tracerは、PyTorchモデルのforwardメソッドを実行する際に、実際のテンソル値の代わりに「プロキシ (Proxy)」と呼ばれる特殊なオブジェクトを渡します。create_proxy()は、このプロキシオブジェクトを生成する役割を担います。

  2. 操作の記録: 生成されたプロキシオブジェクトに対してPyTorchの演算(torch.addnn.Linearの呼び出しなど)が実行されると、Tracerはこれらの操作を記録し、内部的な中間表現であるGraphにノードとして追加します。create_proxy()は、この記録プロセスにおいて、新しい操作の結果を表す新しいプロキシを生成するために呼び出されます。

  3. シンボリック実行の実現: Tracerは、実際の値ではなくプロキシを使ってモデルの実行を「シンボリックに」行います。これにより、モデルの具体的な入力値に依存せず、実行される計算グラフの構造を抽出することができます。create_proxy()は、このシンボリック実行の過程で、各ステップの結果を抽象的に表現するプロキシを動的に生成します。

具体的な動作のイメージ

例えば、以下のような簡単なPyTorchモデルがあるとします。

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def forward(self, x):
        return x + 1

torch.fx.symbolic_trace(MyModule()) のようにこのモジュールをトレースすると、内部でTracerが動きます。

  1. MyModule().forward(x) が呼び出される際、x は実際のテンソルではなく、Tracerによって生成された「プロキシ」として渡されます。このプロキシの生成にcreate_proxy()が関与します。
  2. x + 1 という演算が行われると、Tracerはこの演算を捕捉します。この演算の結果(x + 1の値)もまたプロキシとして表現され、この新しいプロキシの生成にもcreate_proxy()が呼び出されます。
  3. 最終的に、この一連の操作がGraphとして記録され、モデルの計算の流れが抽象的に表現されます。


ここでは、torch.fx.Tracer.create_proxy()に関連する、またはその内部で発生しやすい一般的なエラーと、それらのトラブルシューティングについて説明します。

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow (制御フローに関するエラー)

エラーの原因
FXは、Pythonの動的な制御フロー(if文、forループなど)を完全にトレースすることができません。Tracerは、プロキシオブジェクトがこれらの制御フローの条件(例えば、if proxy_tensor > 0:proxy_tensor > 0 の部分)として使用された場合にこのエラーを発生させます。プロキシは具体的な値を持たないため、条件分岐のパスを決定できないためです。

トラブルシューティング

  • torch.jit.scriptの検討
    もし複雑な制御フローが必要で、FXで表現するのが難しい場合、PyTorchのスクリプトモード(torch.jit.script)が代替手段として考えられます。スクリプトモードは、PythonコードをJITコンパイルして、より幅広いPython構文をサポートします。

    • モデルのforwardメソッド内で、入力テンソルの値に依存するような動的なif文やforループを使用しないようにモデルを書き換えます。
    • 例えば、テンソルの形状に依存するループなど、トレース時に形状が固定される場合は、それを定数に置き換えることを検討します。
    • 動的な制御フローが必要な場合は、その部分をis_leaf_moduleでマークするか、トレースの範囲外に移動させることを検討します。

TypeError または AttributeError (未対応のPythonオブジェクト/操作)

エラーの原因
TracerはPyTorchのテンソル操作やnn.Moduleの呼び出しを記録するように設計されています。Pythonのネイティブなデータ型(リスト、辞書、数値など)や、PyTorchのテンソルに直接関係しないPythonオブジェクトに対する操作は、デフォルトでは適切にプロキシ化されず、エラーが発生することがあります。特に、プロキシオブジェクトに対して、実際のテンソルに適用されるべきではないPythonのメソッドを呼び出したり、型が期待されるものではない場合に発生します。

トラブルシューティング

  • カスタムのTracerの利用
    特定の非テンソル操作やカスタムオブジェクトのトレースをサポートする必要がある場合、torch.fx.Tracerをサブクラス化し、create_proxy()や他のメソッドをオーバーライドして、カスタムロジックを追加することができます。ただし、これは高度な使い方であり、FXの内部構造を理解している必要があります。

  • 未対応のライブラリ/操作

    • numpyのようなPyTorchとは異なるライブラリの関数をモデル内で直接呼び出すと、FXはそれを解釈できません。
    • 可能な限りPyTorchのAPIを使用するようにコードを書き換えます。
    • どうしても外部ライブラリが必要な場合は、その部分をサブリとして扱い、is_leaf_moduleなどでトレースから除外することを検討します。
  • Pythonの組み込み型との混同

    • 例えば、Proxyオブジェクトをリストやタプルのように直接イテレートしようとするとエラーになります (Proxy object cannot be iterated)。
    • Proxyオブジェクトを数値のように直接比較したり、ハッシュ値を取得しようとするとエラーになることがあります。
    • これらの操作が必要な場合は、トレースの対象外とするか、テンソル操作として表現できるか検討します。

RuntimeError: "aten::..." is not an aten function (アテン関数が見つからない)

エラーの原因
これはcreate_proxy()自体が直接引き起こすエラーというよりは、TracerがPyTorchのオペレーション(通常はaten名前空間の関数)を記録しようとしたときに、認識できない関数が呼び出された場合に発生します。これは、内部的なAPIの使用や、PyTorchのバージョン間の互換性の問題で発生することがあります。

トラブルシューティング

  • 非推奨/内部APIの回避
    公式に公開されていない、または非推奨のPyTorchの内部APIを使用している可能性があります。できる限り公開されているAPIを使用するようにコードを書き換えます。
  • PyTorchのバージョン確認
    使用しているPyTorchのバージョンが古い場合や、実験的な機能を使用している場合に発生することがあります。最新の安定版PyTorchを使用しているか確認します。

トラブルシューティング

  • 最小限の再現可能なコードの作成
    問題を切り分けるために、エラーを再現する最小限のコードスニペットを作成します。これにより、問題の特定と解決が容易になります。
  • PyTorch FXのドキュメントとGitHub Issueの参照
    発生したエラーメッセージをPyTorchの公式ドキュメントやGitHubのFX関連のIssueで検索すると、同じ問題に遭遇した他の開発者の解決策が見つかることがあります。
  • 入出力の確認
    モデルのforwardメソッドの入力や、途中のテンソルの形状、データ型などが期待通りになっているか確認します。FXはシンボリックトレースなので、実際のテンソル値が問題になることは少ないですが、形状の不整合などが問題になることはあります。
  • モデルの単純化
    複雑なモデルでエラーが発生した場合、まずはモデルを可能な限り単純化して、どの部分でエラーが発生しているかを特定します。
  • torch.fx.symbolic_traceの利用
    通常はTracerを直接インスタンス化するよりも、高レベルなtorch.fx.symbolic_trace()関数を使用することが推奨されます。これにより、一般的なトレースのユースケースが簡素化されます。
  • print文の活用
    デバッグのために、モデルのforwardメソッド内でprint文を使って、プロキシオブジェクトがどのような形で伝播しているかを確認することができます。ただし、Proxyオブジェクト自体は直接プリントしても意味のある値は表示されないことが多いので、例えばisinstance(x, torch.fx.Proxy)などで型を確認するなど、工夫が必要です。
  • FXのトレース限界の理解
    FXは、Pythonのサブセットをトレースするように設計されています。特に、データに依存する制御フローや、PyTorchのテンソル操作に直接関係しないPythonの組み込み型やオブジェクトの操作は、トレースが難しいか、サポートされていません。FXのトレースの限界を理解することが重要です。


しかし、create_proxy()がどのように機能するかを理解するため、そしてFXのトレーシングメカニズムをより深く探求したい場合のために、カスタムのTracerを実装し、その中でcreate_proxy()が呼び出される状況を示す例を挙げることができます。

重要な注意点
以下の例は、create_proxy()の動作をデモンストレーションするためのものであり、一般的なPyTorch FXの使用方法を示すものではありません。通常は、このようにカスタムTracerを実装する必要はありません。

例1: カスタムTracerでcreate_proxy()の呼び出しを監視する

この例では、Tracerをサブクラス化し、create_proxy()メソッドをオーバーライドして、その呼び出しをログに記録します。これにより、トレース中にいつプロキシが生成されるかを確認できます。

import torch
import torch.nn as nn
from torch.fx.api import Tracer
from torch.fx.proxy import Proxy

# カスタムTracerの定義
class MyCustomTracer(Tracer):
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None):
        """
        create_proxy メソッドをオーバーライドし、呼び出しをログに記録する
        """
        print(f"--- create_proxy called ---")
        print(f"  kind: {kind}")      # 'placeholder', 'call_function', 'call_method', 'call_module', 'get_attr'
        print(f"  target: {target}")  # 呼び出される関数、メソッド、モジュール、属性など
        print(f"  args: {args}")      # 引数 (Proxyオブジェクトを含む)
        print(f"  kwargs: {kwargs}")  # キーワード引数
        print(f"  name: {name}")      # 生成されるノードの名前 (自動生成されることが多い)
        print(f"-------------------------")

        # 基底クラスのcreate_proxyを呼び出して、実際のプロキシを生成させる
        return super().create_proxy(kind, target, args, kwargs, name, type_expr)

# シンプルなPyTorchモデル
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        # 1. 最初の入力 'x' (placeholder) のプロキシ
        # 2. self.linear モジュールの呼び出し
        x = self.linear(x)
        # 3. torch.relu 関数の呼び出し
        x = torch.relu(x)
        # 4. x + 1 の演算 (add オペレーション)
        return x + 1

# カスタムTracerを使ってモデルをトレース
tracer = MyCustomTracer()
graph = tracer.trace(SimpleModel())

print("\n--- Generated Graph ---")
graph.print_tabular()

print("\n--- Python Code from Graph ---")
print(graph.to_code())

コードの解説

  1. MyCustomTracer(Tracer): torch.fx.Tracerを継承したカスタムトレーサーを定義します。
  2. create_proxy(...)のオーバーライド:
    • このメソッド内で、create_proxy()が呼び出された際の引数(kind, target, args, kwargsなど)をプリントしています。
    • kindは、プロキシが何を表しているかを示します。
      • 'placeholder': モデルのforwardメソッドの入力引数。
      • 'call_module': nn.Moduleの呼び出し(例: self.linear(x))。
      • 'call_function': torch.relu()のようなPythonの関数呼び出し。
      • 'call_method': テンソルオブジェクトのメソッド呼び出し(例: x.sum())。
      • 'get_attr': モジュールの属性へのアクセス(例: self.some_param)。
    • 最後にsuper().create_proxy(...)を呼び出すことで、基底クラスの実際のプロキシ生成ロジックを実行させます。
  3. SimpleModel: 線形層とReLU、加算を含むシンプルなモデルです。
  4. tracer.trace(SimpleModel()): 定義したカスタムトレーサーのインスタンスを使って、モデルをトレースします。このトレース中に、SimpleModelforwardメソッドがシンボリックに実行され、その過程でMyCustomTracercreate_proxy()メソッドが複数回呼び出されます。

実行結果(抜粋)

--- create_proxy called ---
  kind: placeholder
  target: x
  args: ()
  kwargs: {}
  name: x
-------------------------
--- create_proxy called ---
  kind: call_module
  target: linear
  args: (<Proxy obj at 0x...>,) # ここで 'x' のプロキシが引数として渡されている
  kwargs: {}
  name: linear
-------------------------
--- create_proxy called ---
  kind: call_function
  target: <built-in function relu> # torch.relu
  args: (<Proxy obj at 0x...>,) # linearの出力のプロキシが引数
  kwargs: {}
  name: relu
-------------------------
--- create_proxy called ---
  kind: call_function
  target: <built-in function add> # x + 1 の '+' 演算
  args: (<Proxy obj at 0x...>, 1) # reluの出力のプロキシが引数
  kwargs: {}
  name: add
-------------------------

--- Generated Graph ---
opcode         name         target         args      kwargs
-------------  -----------  -------------  --------  --------
placeholder    x            x              ()        {}
call_module    linear       linear         (x,)      {}
call_function  relu         <built-in ...  (linear,) {}
call_function  add          <built-in ...  (relu, 1) {}
output         output       output         (add,)    {}

--- Python Code from Graph ---
def forward(self, x):
    linear = self.linear(x);  
    relu = torch.relu(linear);  
    add = relu + 1;  
    return add

この出力から、xというプレースホルダー、linearモジュールの呼び出し、relu関数の呼び出し、そしてadd+演算子)の呼び出しそれぞれで、新しいプロキシが生成され、その情報がcreate_proxy()を通じて記録されていることがわかります。

create_proxy()type_expr引数を持ち、生成されるノードの型ヒントを指定できます。これは通常、FXが自動的に推論しますが、カスタムで型を明示したい場合に利用できます。

import torch
import torch.nn as nn
from torch.fx.api import Tracer
from torch.fx.proxy import Proxy
from typing import Tuple # 型ヒントのためにインポート

class MyTypingTracer(Tracer):
    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None):
        # ここでは、特定の操作の型ヒントを変更する例を示す
        if target is torch.add: # torch.add 演算の結果を特定の型に指定
            # 例えば、タプルを返すかのように見せかける (あくまで型ヒントの例)
            print(f"Overriding type_expr for {target} to Tuple[torch.Tensor, torch.Tensor]")
            type_expr = Tuple[torch.Tensor, torch.Tensor] # 実際は単一テンソルだが、型ヒントを上書き
        
        # それ以外はデフォルトの動作
        return super().create_proxy(kind, target, args, kwargs, name, type_expr)

class ModelWithAdd(nn.Module):
    def forward(self, x, y):
        # x と y は両方ともテンソルを想定
        return x + y

tracer = MyTypingTracer()
graph = tracer.trace(ModelWithAdd())

print("\n--- Generated Graph with Type Hint ---")
graph.print_tabular()

# 生成されたコードを見て、型ヒントがどう影響するかを確認する
# (ただし、to_code()が型ヒントを明示的に出力しない場合もある)
print("\n--- Python Code from Graph ---")
print(graph.to_code())

# ノードの型情報を確認する (GraphIteratorを使って)
print("\n--- Node Type Information ---")
for node in graph.nodes:
    print(f"Node: {node.name}, Opcode: {node.op}, Target: {node.target}, Type: {node.type}")

コードの解説

  1. MyTypingTracerでは、torch.addターゲットの場合に、type_exprTuple[torch.Tensor, torch.Tensor]に強制的に設定しています。
  2. ModelWithAddは単純な加算を行います。
  3. トレース後、生成されたグラフのノードのtype属性をチェックすることで、create_proxy()で設定した型ヒントが反映されているかを確認できます。

この例は、create_proxy()が生成するプロキシ(ひいてはグラフのノード)のメタデータに影響を与えることができることを示しています。これは、FXグラフに対して静的解析ツールを適用したり、特定の最適化パスを実装したりする際に役立つ可能性があります。



そのため、「create_proxy() に代わるプログラミング方法」というよりは、「FX でのプロキシの生成とグラフ構築を制御するための、より高レベルな代替手段や、FX を使わずに同様の目的を達成する方法」として説明するのが適切です。

torch.fx.symbolic_trace() の利用 (最も一般的で推奨される方法)

torch.fx.Tracer.create_proxy() が提供する機能のほとんどは、torch.fx.symbolic_trace() を使うことで、よりシンプルかつ安全に利用できます。これが FX の標準的なインターフェースです。

  • なぜ代替となるか?: create_proxy() の低レベルな詳細を意識することなく、モデルの計算グラフを抽出できます。ほとんどのユースケースではこれで十分です。
  • 何をするか?: モデルの forward メソッドをシンボリックに実行し、PyTorch のオペレーションを FX の Graph として記録します。この過程で、Tracer が内部的に create_proxy() を呼び出して、各中間結果のプロキシを生成し、グラフノードとして追加します。


import torch
import torch.nn as nn
import torch.fx

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return torch.relu(self.linear(x))

# symbolic_trace を使ってモデルをトレースするだけ
traced_model = torch.fx.symbolic_trace(MyModel())

print("--- Generated Graph ---")
traced_model.graph.print_tabular()

print("\n--- Python Code ---")
print(traced_model.code)

torch.fx.Tracer のカスタムサブクラス化 (より高度な制御が必要な場合)

create_proxy() を直接変更する唯一の現実的な代替手段は、torch.fx.Tracer をサブクラス化し、その中の他のメソッドをオーバーライドすることです。これにより、トレースプロセスをカスタマイズできます。

  • なぜ代替となるか?: create_proxy() 自体を直接変更するよりも、トレースの特定の側面(例えば、特定のモジュールの扱い方や、特定の関数のグラフへの記録方法)を制御したい場合に適しています。

  • 何をするか?:

    • is_leaf_module(): 特定のモジュールが「葉」ノードとして扱われるべきか(つまり、その内部はトレースされない)を定義します。これにより、トレースの粒度を制御できます。
    • call_module(), call_function(), call_method(): これらのメソッドをオーバーライドすることで、特定のモジュール呼び出し、関数呼び出し、メソッド呼び出しがグラフにどのように記録されるかをカスタマイズできます。例えば、特定の呼び出しの挙動を変更したり、追加の情報をノードに付与したりできます。
    • create_proxy() を直接オーバーライドするのではなく、これらのメソッド内で create_proxy() が呼び出される前に処理を挿入することで、間接的にプロキシ生成に影響を与えます。


import torch
import torch.nn as nn
from torch.fx import Tracer, GraphModule, Proxy, Node

class MyCustomTracer(Tracer):
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # nn.Linear モジュールは、内部をトレースせず、単一のノードとして扱う
        if isinstance(m, nn.Linear):
            return True
        return super().is_leaf_module(m, module_qualified_name)

    # (オプション) 特定の関数呼び出しをカスタマイズする例
    def call_function(self, fn, args, kwargs):
        # torch.relu の呼び出しに特定のメタデータを追加したい場合など
        if fn is torch.relu:
            print(f"Customizing call for torch.relu with args: {args}")
            # ここでカスタムの処理を行い、その後で基底クラスのメソッドを呼び出す
            proxy = super().call_function(fn, args, kwargs)
            # 例: プロキシのノードにカスタムのアトリビュートを追加 (これはFXのノードのAPIに依存)
            # proxy.node.meta['custom_tag'] = 'relu_activated' # Node.meta は辞書としてアクセス可能
            return proxy
        return super().call_function(fn, args, kwargs)

class AnotherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.linear = nn.Linear(16*10*10, 10) # 仮のサイズ

    def forward(self, x):
        # Conv層はトレースされ、Linear層はリーフモジュールとして扱われる
        x = self.conv(x)
        x = torch.relu(x)
        # Flatten は PyTorch 2.x の symbolic_trace で自動的に処理されることが多い
        x = x.view(x.size(0), -1) 
        x = self.linear(x)
        return x

tracer = MyCustomTracer()
# ダミーの入力を用意
dummy_input = torch.randn(1, 3, 10, 10)
graph = tracer.trace(AnotherModel(), dummy_input)
traced_module = GraphModule(tracer.root, graph)

print("\n--- Generated Graph with Custom Tracer ---")
traced_module.graph.print_tabular()
print("\n--- Python Code from Custom Tracer ---")
print(traced_module.code)

JIT Tracing (torch.jit.trace) の利用 (FXとは異なる目的)

FX とは異なり、torch.jit.trace は実際のデータフローに基づいて PyTorch の Script モードの Graph を構築します。これは計算グラフをキャプチャしますが、FX のようなグラフ変換やメタプログラミングには向いていません。

  • なぜ代替となるか?:
    • 利点: FX が苦手とするデータ依存の制御フロー(if input.sum() > 0: のような)を扱うことができますが、トレース時に実行されたパスのみが記録されます。また、モデルのデプロイメント(C++など)に適しています。
    • 欠点: FX のように、モデルの様々な入力形状に対応できる汎用的なグラフを生成したり、グラフに対して複雑な変換を行うことには適していません。また、プロキシオブジェクトを直接操作するような柔軟性はありません。
  • 何をするか?: 実際の入力データを使ってモデルを実行し、実行されたオペレーションを記録します。これにより、JIT コンパイルされたモデル(TorchScript)が生成されます。


import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.RNN(10, 20, 2)

    def forward(self, input_seq):
        output, hn = self.rnn(input_seq)
        return output

# JIT tracing を使用
input_seq = torch.randn(5, 3, 10) # seq_len, batch_size, input_size
traced_rnn = torch.jit.trace(SimpleRNN(), input_seq)

print("\n--- JIT Traced Graph (TorchScript) ---")
# TorchScript のグラフは FX とは異なる形式
print(traced_rnn.graph) 

究極の代替手段は、torch.fx.Graphtorch.fx.Node オブジェクトを直接操作して、手動で計算グラフを構築することです。これは非常に低レベルであり、ほとんどのユースケースでは不要です。

  • なぜ代替となるか?:
    • 利点: 理論上はあらゆる種類のグラフを構築できます。シンボリックトレースではキャプチャできない非常に特殊なロジックを表現する必要がある場合に限られます。
    • 欠点: 非常に冗長でエラーを起こしやすいです。モデルの複雑さに応じて、コードの量が爆発的に増えます。デバッグも困難です。
  • 何をするか?: Graph オブジェクトをインスタンス化し、node.append()node.create_arg() などのメソッドを使って、個々のオペレーションを表す Node を追加していきます。create_proxy() が内部で行っているプロキシ生成とノード追加のプロセスを、すべて手動で行います。


import torch
from torch.fx import Graph, Node, Proxy, GraphModule

# 新しいグラフを初期化
g = Graph()

# 入力プレースホルダーノードを作成 (これは create_proxy('placeholder', ...) に相当)
x = g.placeholder('x')

# 加算操作を作成 (これは create_proxy('call_function', torch.add, ...) に相当)
# node.call_function(target_function, args_tuple, kwargs_dict)
add_one = g.call_function(torch.add, (x, 1))

# 出力ノードを作成
g.output(add_one)

# Graph を GraphModule に変換して実行可能にする
manual_model = GraphModule({}, g) # 空のNN.Module辞書でOK、モジュールは含まないため

print("\n--- Manually Constructed Graph ---")
manual_model.graph.print_tabular()

print("\n--- Python Code from Manual Graph ---")
print(manual_model.code)

# 実行例
dummy_input = torch.tensor(5)
output = manual_model(dummy_input)
print(f"\nManual model output for {dummy_input}: {output}") # 5 + 1 = 6

torch.fx.Tracer.create_proxy() の代替手段を考える場合、ほとんどのケースではtorch.fx.symbolic_trace() を使用することが推奨されます。これは、FX の強力な機能を安全かつ効率的に利用するための主要な方法です。