torch.fx.Tracer.is_leaf_module()のプログラミング例:カスタムなグラフ構築

2025-05-31

このメソッドは、PyTorchのFXグラフのトレース処理において、与えられたtorch.nn.Moduleのインスタンス(m)が「リーフモジュール」として扱われるべきかどうかを判定するために使用されます。

「リーフモジュール」とは何か?

FXグラフのトレース処理では、PyTorchのモデルを構成する各モジュールをノードとしてグラフ化します。この際、あるモジュールがそれ以上深くトレースされず、その内部構造がFXグラフの個別のノードとして展開されない場合、そのモジュールは「リーフモジュール」と見なされます。

is_leaf_module() の役割

is_leaf_module() メソッドは、Tracer がモデルの各モジュールを訪れる際に呼び出されます。このメソッドの戻り値に基づいて、Tracer は以下のいずれかの処理を行います。

  • False を返した場合
    そのモジュールはリーフモジュールとは見なされず、Tracer はそのモジュール内部のサブモジュールや演算をさらに深くトレースし、FXグラフの個別のノードとして展開します。
  • True を返した場合
    そのモジュールはリーフモジュールとして扱われ、そのモジュール自体がFXグラフの単一のノードとして表現されます。そのモジュール内部の演算はトレースされません。

デフォルトの動作

デフォルトでは、Tracertorch.nn.Module のほとんどの基本的なモジュール(例:torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU など)をリーフモジュールとして扱います。これは、これらのモジュールの内部演算は比較的単純であり、個別のFXノードとして表現するよりも、モジュール全体を一つのノードとして扱う方がグラフの可読性や操作性が向上するためです。

qualified_name 引数

qualified_name 引数は、モデル内でのモジュールの階層的な名前(例:layer1.0.conv1)を表します。この名前を使って、特定のモジュールに対してリーフモジュールとしての振る舞いをカスタマイズすることができます。

利用場面

is_leaf_module() メソッドは、torch.fx.Tracer をサブクラス化し、その振る舞いをカスタマイズしたい場合に主に利用されます。例えば、以下のような場合に役立ちます。

  • デフォルトではリーフモジュールとして扱われるモジュールをさらに深くトレースしたい場合
    例えば、ある特定の torch.nn.Linear モジュールの内部演算を詳細に分析したい場合に、このメソッドをオーバーライドして False を返すように実装します。
  • 特定のカスタムモジュールをリーフモジュールとして扱いたい場合
    複雑な処理を行うカスタムモジュールがあり、その内部構造をFXグラフに含めたくない場合に、このメソッドをオーバーライドして True を返すように実装します。

import torch
import torch.nn as nn
from torch.fx import Tracer

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

class CustomTracer(Tracer):
    def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
        # MyModule 全体をリーフモジュールとして扱う
        if isinstance(m, MyModule):
            return True
        # それ以外のモジュールはデフォルトの挙動に従う
        return super().is_leaf_module(m, qualified_name)

model = MyModule()
tracer = CustomTracer()
graph = tracer.trace(model)
print(graph)

この例では、CustomTraceris_leaf_module() メソッドをオーバーライドして、MyModule のインスタンスが渡された場合に True を返すようにしています。これにより、生成されるFXグラフでは MyModule 全体が 하나의 노ードとして表現され、その内部の linear1, relu, linear2 は個別のノードとして展開されません。



一般的なエラー

  1. 型アノテーションの不一致
    is_leaf_module() メソッドの定義は (m: torch.nn.Module, qualified_name: str) -> bool です。オーバーライドしたメソッドの引数や戻り値の型がこれと一致しない場合、TypeErrorが発生する可能性があります。

    class CustomTracer(Tracer):
        def is_leaf_module(self, module, name): # 型アノテーションがない、または異なる型
            return isinstance(module, MyCustomModule)
    

    トラブルシューティング
    メソッドのシグネチャを正確に (m: torch.nn.Module, qualified_name: str) -> bool に合わせます。

  2. 論理的な誤りによる意図しないリーフ化
    is_leaf_module() の条件判定が誤っていると、本来トレースしたいモジュールがリーフモジュールとして扱われてしまい、FXグラフに必要な情報が含まれなくなることがあります。

    class CustomTracer(Tracer):
        def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
            # すべての線形層をリーフモジュールにしてしまう誤った例
            return isinstance(m, nn.Linear)
    

    トラブルシューティング
    条件判定が意図通りに動作しているか、さまざまなモジュールに対してテストを行い、FXグラフの生成結果を確認します。

  3. super().is_leaf_module() の呼び出し忘れ
    デフォルトのリーフモジュールとしての振る舞いを維持しつつ、特定のモジュールに対してカスタムのリーフ判定を行いたい場合、適切なタイミングで super().is_leaf_module(m, qualified_name) を呼び出す必要があります。これを忘れると、デフォルトのリーフ判定が機能しなくなります。

    class CustomTracer(Tracer):
        def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
            if isinstance(m, MyCustomModule):
                return True
            # super() の呼び出しを忘れている
            return False
    

    トラブルシューティング
    カスタムの条件に合致しない場合は、必ず super().is_leaf_module(m, qualified_name) を呼び出すようにします。

  4. qualified_name の誤用
    qualified_name はモジュールの階層的な名前であり、これを利用して特定のモジュールをリーフ化する際に、名前を間違えると意図したモジュールがリーフ化されません。

    class CustomTracer(Tracer):
        def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
            # モジュールの名前を間違えている
            return qualified_name == "my_module.sub_module.wrong_name"
    

    トラブルシューティング
    トレース対象のモデルの構造を把握し、qualified_name が正しいことを確認します。トレース中に qualified_name をログ出力するなどして確認するのも有効です。

トラブルシューティング

  1. FXグラフの確認
    Tracer でトレースした結果の torch.fx.Graph オブジェクトの内容を print(graph) などで確認し、意図したノード構成になっているかを確認します。リーフモジュールとして扱われるべきでないモジュールがリーフになっている、あるいはその逆のケースがないかを目視でチェックします。

  2. 中間層の出力の確認
    FXグラフを解釈・実行する際に、期待される中間層の出力が得られているかを確認します。意図しないリーフ化によって、必要な演算がグラフに含まれていない可能性があります。

  3. 簡単なモデルでのテスト
    複雑なモデルで問題が発生する場合は、よりシンプルなモデルを作成し、is_leaf_module() のカスタム実装が意図通りに動作するかどうかを個別にテストします。

  4. ログ出力の活用
    is_leaf_module() メソッド内で、引数 m の型や qualified_name の値、そして戻り値をログ出力するようにして、トレース処理中にどのように判定が行われているかを確認します。

    import logging
    logging.basicConfig(level=logging.INFO)
    
    class CustomTracer(Tracer):
        def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
            is_leaf = isinstance(m, MyCustomModule)
            logging.info(f"Module: {qualified_name}, Type: {type(m)}, is_leaf: {is_leaf}")
            return is_leaf
    
  5. PyTorch FX のドキュメントやチュートリアルを参照
    PyTorch FX の公式ドキュメントや関連するチュートリアルを再度確認し、is_leaf_module() の正しい使い方や、トレース処理の仕組みを理解を深めます。



例1: 特定のカスタムモジュールをリーフモジュールとして扱う

import torch
import torch.nn as nn
from torch.fx import Tracer

# カスタムモジュールの定義
class MyCustomOperation(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(5, 10))
        self.bias = nn.Parameter(torch.randn(5))

    def forward(self, x):
        return torch.matmul(x, self.weight.T) + self.bias

# モデルの定義
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.custom_op = MyCustomOperation()
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.custom_op(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# カスタムトレーサーの定義
class CustomTracer(Tracer):
    def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
        # MyCustomOperation のインスタンスをリーフモジュールとして扱う
        if isinstance(m, MyCustomOperation):
            return True
        # それ以外のモジュールはデフォルトの挙動に従う
        return super().is_leaf_module(m, qualified_name)

# モデルのインスタンス化とトレース
model = MyModel()
tracer = CustomTracer()
graph = tracer.trace(model)

# 生成された FX グラフの表示
print(graph)

この例では、MyCustomOperation というカスタムモジュールを定義しています。CustomTraceris_leaf_module() メソッドをオーバーライドし、モジュールが MyCustomOperation のインスタンスである場合に True を返すようにしています。その結果、生成されるFXグラフでは MyCustomOperation の内部構造(weightbias パラメータ、matmul や加算演算)は展開されず、custom_op という一つのノードとして表現されます。

例2: 特定の名前を持つモジュールをリーフモジュールとして扱う

import torch
import torch.nn as nn
from torch.fx import Tracer

class SubModule(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

class MainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = SubModule(10, 20)
        self.layer2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class NameBasedLeafTracer(Tracer):
    def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
        # 'layer1' という名前のモジュールをリーフモジュールとして扱う
        if qualified_name == "layer1":
            return True
        return super().is_leaf_module(m, qualified_name)

model = MainModel()
tracer = NameBasedLeafTracer()
graph = tracer.trace(model)
print(graph)

この例では、NameBasedLeafTraceris_leaf_module() メソッドで、引数 qualified_name を利用しています。qualified_name"layer1" である場合(MainModelself.layer1)、その SubModule インスタンスをリーフモジュールとして扱います。したがって、FXグラフでは layer1 の内部構造(linearrelu)は展開されず、一つのノードとして表現されます。

例3: デフォルトのリーフモジュールをさらに深くトレースする

デフォルトでは nn.Linear などはリーフモジュールとして扱われますが、この振る舞いを変更して内部をトレースする例です。

import torch
import torch.nn as nn
from torch.fx import Tracer

class DeepTraceLinearTracer(Tracer):
    def is_leaf_module(self, m: nn.Module, qualified_name: str) -> bool:
        # nn.Linear の場合はリーフモジュールとしない (False を返す)
        if isinstance(m, nn.Linear):
            return False
        return super().is_leaf_module(m, qualified_name)

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

model = SimpleModel()
tracer = DeepTraceLinearTracer()
graph = tracer.trace(model)
print(graph)

この例では、DeepTraceLinearTraceris_leaf_module() メソッドで、モジュールが nn.Linear のインスタンスである場合に False を返しています。これにより、通常はリーフモジュールとして扱われる nn.Linear の内部演算(重みとバイアスの操作など)もFXグラフの個別のノードとしてトレースされます。生成されるグラフを見ると、linear モジュール内で行われる matmul や加算などの演算がノードとして現れることがわかります。



torch.fx.symbolic_trace() 関数の concrete_args 引数

torch.fx.symbolic_trace() 関数は、Tracer を内部で使用してモデルをトレースしますが、concrete_args 引数を使用することで、トレース時に特定のモジュールの forward メソッドに具体的な値を渡すことができます。これにより、そのモジュール内の一部の処理が定数畳み込みなどによってグラフから削除されたり、特定のパスのみがトレースされたりする可能性があります。これは、is_leaf_module() とは異なるアプローチですが、結果的にグラフの複雑さを軽減したり、特定のモジュールの内部を詳細に把握したりするのに役立つことがあります。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.should_skip = False

    def forward(self, x, skip: bool):
        self.should_skip = skip
        if skip:
            return x
        else:
            return self.linear(x)

model = MyModule()
# skip=True の場合のトレース
graph_skip_true = symbolic_trace(model, concrete_args={'skip': True})
print("skip=True の場合のグラフ:")
print(graph_skip_true)

# skip=False の場合のトレース
graph_skip_false = symbolic_trace(model, concrete_args={'skip': False, 'x': torch.randn(1, 10)})
print("\nskip=False の場合のグラフ:")
print(graph_skip_false)

この例では、concrete_args を用いて skip 引数の値を固定してトレースすることで、条件分岐によるグラフの変化を観察できます。

属性へのアクセス制御 (__getattr__ のオーバーライド)

モデルの属性へのアクセスを制御することで、Tracer が特定のサブモジュールを認識しないようにすることができます。これにより、そのサブモジュールはトレースされず、親モジュールの一部として扱われるようになります。ただし、これは Tracer の内部動作に深く関わるため、注意が必要です。

import torch
import torch.nn as nn
from torch.fx import Tracer

class InnerModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(3))

    def forward(self, x):
        return x + self.param

class OuterModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._inner = InnerModule() # _ で始まる属性は慣例的に内部使用を示す

    def forward(self, x):
        return self._inner(x) * 2

class HiddenInnerModuleTracer(Tracer):
    def getattr(self, obj, attr):
        if attr == '_inner' and isinstance(obj, OuterModule):
            # _inner 属性へのアクセスをフックして None を返す
            return None
        return super().getattr(obj, attr)

model = OuterModule()
tracer = HiddenInnerModuleTracer()
graph = tracer.trace(model)
print(graph)

この例では、HiddenInnerModuleTracergetattr メソッドをオーバーライドし、OuterModule_inner 属性へのアクセス時に None を返すようにしています。これにより、Tracer_inner を通常のサブモジュールとして認識せず、トレース対象から除外される可能性があります(ただし、FXのバージョンや内部実装によって挙動が変わる可能性があります)。

事前処理によるモデルの変更

トレースを行う前に、モデル自体を操作して、FXグラフに含めたくない部分を削除したり、より単純なモジュールに置き換えたりする方法です。例えば、複雑なカスタムモジュールを、等価なより基本的な演算の組み合わせに分解してからトレースするなどです。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class ComplexModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)

    def complicated_operation(self, x):
        return torch.relu(self.linear1(x)) + self.linear2(x)

    def forward(self, x):
        return self.complicated_operation(x)

class SimplifiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        return torch.relu(self.linear1(x)) + self.linear2(x)

original_model = ComplexModule()
simplified_model = SimplifiedModel() # ComplexModule と等価だが、トレースしやすい構造

graph_original = symbolic_trace(original_model)
print("元のモデルのグラフ:")
print(graph_original)

graph_simplified = symbolic_trace(simplified_model)
print("\n簡略化されたモデルのグラフ:")
print(graph_simplified)

この例では、ComplexModulecomplicated_operation を直接トレースする代わりに、等価な処理を行う SimplifiedModel を作成してトレースしています。

FXグラフの手動構築

最も低レベルな方法として、torch.fx.Graph オブジェクトを直接操作して、必要なノードを手動で追加していく方法があります。これは、既存のモデルをトレースするのではなく、完全に新しいグラフを生成する場合や、トレース結果を細かく編集する場合に用いられます。Tracercreate_node() メソッドなどを利用してノードを作成し、接続を定義します。

import torch
import torch.fx.graph as fx_graph
from torch.fx import Tracer

# 空のグラフを作成
graph = fx_graph.Graph()

# 入力ノードの作成
input_node = graph.create_node(op='placeholder', name='input', args=())

# 線形層のノード(手動でオペコードなどを指定)
linear_weight = graph.create_node(op='get_attr', name='linear_weight', args=('linear.weight',))
linear_bias = graph.create_node(op='get_attr', name='linear_bias', args=('linear.bias',))
linear_output = graph.create_node(op='call_function', target=torch.matmul, args=(input_node, linear_weight))
linear_output_biased = graph.create_node(op='call_function', target=torch.add, args=(linear_output, linear_bias))

# ReLU のノード
relu_output = graph.create_node(op='call_function', target=torch.relu, args=(linear_output_biased,))

# 出力ノードの作成
output_node = graph.create_node(op='output', name='output', args=(relu_output,))

graph.lint() # グラフの整合性をチェック
print(graph)

この例では、Tracer を使用せずに、torch.fx.Graph を直接操作して簡単な線形層と ReLU の処理を表すグラフを構築しています。