【初心者向け】PyTorch FXのon_generate_code():基本から応用までわかりやすく解説

2025-05-31

torch.fx.Graph.on_generate_code() は、torch.fx.Graph クラスに定義されているメソッドの一つで、グラフからPythonソースコードを生成する処理をカスタマイズするために用いられます。

FX (PyTorch Function eXchange) は、PyTorchモデルを中間表現(Graph)として捉え、そのグラフに対して様々な変換や最適化を行うためのフレームワークです。最終的に、このグラフから実行可能なPythonコードを生成する必要があります。on_generate_code() は、このコード生成の特定の段階にフックし、ユーザーが独自の処理を挿入できるようにするための仕組みです。

具体的には、torch.fx.GraphModule のインスタンスに対して code プロパティや forward メソッドが初めてアクセスされる際に、FXは内部的にグラフからPythonコードを生成します。このコード生成プロセスの中で、on_generate_code() メソッドが定義されていれば、それが呼び出されます。

このメソッドは、以下の点で役立ちます。

  • 特殊な処理の挿入
    コード生成のタイミングで、特定の条件に基づいて異なるコードを生成したり、ログ出力などの処理を挟んだりすることができます。
  • コードの整形や装飾
    生成されるコードに対して、インデントの調整、コメントの追加、特定のライブラリのインポートなどを自動的に行うことができます。
  • カスタムノードのコード生成
    FXがデフォルトでサポートしていないような独自の演算(カスタムオペレータなど)をグラフに含めている場合、on_generate_code() を実装することで、それらのノードに対応するPythonコードを生成する方法をFXに指示できます。

on_generate_code() メソッドは、通常、torch.fx.Graph クラスを継承したカスタムクラスの中で定義されます。このメソッドは引数として現在の torch.fx.Graph オブジェクトを受け取り、何も返しません。メソッド内でグラフの状態を直接変更したり、コード生成に必要な情報を内部的に保持したりすることができます。


import torch
import torch.nn as nn
from torch.fx import Graph, GraphModule

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)

    def forward(self, x):
        return self.linear(x).relu()

class CustomGraph(Graph):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_code = []

    def on_generate_code(self):
        self.custom_code.append("# Custom code added during generation!")
        for node in self.nodes:
            if node.op == 'call_module' and node.target == 'linear':
                self.custom_code.append(f"# Found linear layer: {node.name}")

        def get_custom_code():
            return "\n".join(self.custom_code)

        self.get_custom_code = get_custom_code

gm = torch.fx.symbolic_trace(MyModule())
gm.graph.__class__ = CustomGraph

# コード生成をトリガー
code = gm.code
print(code)
print(gm.graph.get_custom_code())

この例では、CustomGraph クラスで on_generate_code() をオーバーライドしています。コード生成時に、特定のコメントを追加したり、linear レイヤーのノードを検出してメッセージを追加したりする処理を記述しています。生成されたコードや、get_custom_code() メソッドを通じて、このカスタム処理の結果を確認できます。



on_generate_code() が呼び出されない

  • トラブルシューティング
    • torch.fx.symbolic_trace などで GraphModule を作成し、その graph 属性のクラスをカスタム Graph クラスに設定しているか確認してください。
    • カスタム GraphModule を作成している場合は、その内部で graph 属性が正しくカスタム Graph のインスタンスになっているか確認してください。
    • コード生成をトリガーするために、GraphModulecode プロパティにアクセスするか、モデルとして使用する場合は forward メソッドを呼び出しているか確認してください。
  • 原因
    • torch.fx.Graph を直接操作しており、torch.fx.GraphModulecode プロパティや forward メソッドにアクセスしていない。on_generate_code() は、これらのプロパティやメソッドが初めてアクセスされる際に内部的にトリガーされます。
    • torch.fx.Graph のインスタンスのクラスをカスタムクラスに置き換えていない。on_generate_code() は、カスタム Graph クラス内でオーバーライドする必要があります。

生成されたコードが不正である

  • トラブルシューティング
    • on_generate_code() 内でグラフのノードを直接変更することは、予期せぬ副作用を引き起こす可能性があるため、慎重に行ってください。必要な場合は、新しいノードを作成してグラフに追加することを検討してください。
    • カスタムノードのコード生成では、そのノードの optargetargskwargs などの情報を正確に解析し、対応するPythonコードを生成するようにしてください。
    • 生成するコードの断片は、最終的に有効なPythonコードとして結合されることを意識し、インデントや改行などを適切に処理してください。生成されたコードを exec() などで実行してテストすることも有効です。
  • 原因
    • on_generate_code() 内でグラフのノード情報を誤って変更している。
    • カスタムノードに対するコード生成ロジックが正しく実装されていない。
    • 生成するコードのシンタックスがPythonの文法に違反している。

スコープと変数名の衝突

  • トラブルシューティング
    • FXが内部的に使用する変数名(例: _tensor_constant0, getattr_linear_weight)などを把握し、カスタムコードでこれらの名前との衝突を避けるようにしてください。
    • 必要であれば、よりユニークな変数名を生成するように工夫してください。
    • ノードの属性や他のノードの結果を参照する際には、FXが提供する仕組み(例: node.name)を利用して、正しいスコープでアクセスするようにしてください。
  • 原因
    • on_generate_code() で生成したコード内で使用している変数名が、FXが生成するデフォルトの変数名と衝突している。
    • カスタムノードのコード生成時に、正しいスコープで変数を参照できていない。

性能への影響

  • トラブルシューティング
    • on_generate_code() 内の処理は、必要な最小限に留めるようにしてください。
    • コード生成の結果をキャッシュするなどして、不要な再計算を避けることを検討してください。
  • 原因
    • on_generate_code() 内で複雑な処理を行っているため、コード生成に時間がかかっている。
    • コード生成のたびに不要な処理を繰り返している。

デバッグの困難さ

  • トラブルシューティング
    • on_generate_code() 内で生成するコードを段階的に構築し、その都度、生成されたコードを出力して確認するようにしてください。
    • ロギング機能を利用して、コード生成の過程で重要な情報を記録するようにしてください。
    • 単体テストを作成し、様々な入力に対して期待されるコードが生成されることを検証してください。
  • 原因
    • on_generate_code() で生成されるコードが複雑で、エラーが発生した場合に追跡が難しい。
    • コード生成の過程でどのようなコードが生成されているのかを確認する手段が限られている。
  • 小さな例から始める
    まずは簡単なグラフと on_generate_code() の実装から始め、徐々に複雑なケースに対応していくことで、問題を切り分けやすくなります。
  • FX のドキュメントとソースコードの参照
    FX の公式ドキュメントやソースコードを参照することで、内部の動作や設計思想を理解し、問題解決のヒントを得ることができます。
  • print デバッグ
    on_generate_code() 内で、グラフのノード情報や生成しようとしているコードの断片を print() 関数で出力して確認することは、基本的ながら非常に有効なデバッグ手法です。


例1: 簡単なコメントの追加

この例では、コード生成時に自動的にコメントを追加するシンプルなカスタム Graph クラスを作成します。

import torch
import torch.nn as nn
from torch.fx import Graph, GraphModule

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)

    def forward(self, x):
        return self.linear(x)

class CommentingGraph(Graph):
    def on_generate_code(self):
        self.code.append("# This code was generated by a custom Graph class.")

gm = torch.fx.symbolic_trace(MyModule())
gm.graph.__class__ = CommentingGraph

print(gm.code)

解説

  1. CommentingGraph クラスは torch.fx.Graph を継承しています。
  2. on_generate_code() メソッド内で、self.code.append() を使用して生成されるコードのリストにコメント文字列を追加しています。FX はこのリストの内容を結合して最終的なPythonコードを生成します。
  3. MyModulesymbolic_trace でトレースし、得られた GraphModulegraph 属性のクラスを CommentingGraph に置き換えます。
  4. gm.code にアクセスすると、コード生成がトリガーされ、on_generate_code() が実行され、生成されたコードの先頭にコメントが追加されていることが確認できます。

例2: 特定のノードに関する情報の埋め込み

この例では、グラフ内の特定の種類のノード(ここでは linear レイヤー)に関する情報を生成されるコードに埋め込みます。

import torch
import torch.nn as nn
from torch.fx import Graph, GraphModule

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)

class NodeInfoGraph(Graph):
    def on_generate_code(self):
        for node in self.nodes:
            if node.op == 'call_module' and isinstance(self.get_submodule(node.target), nn.Linear):
                self.code.append(f"# Found a Linear layer: {node.name}")
                linear_module = self.get_submodule(node.target)
                self.code.append(f"#   Input features: {linear_module.in_features}, Output features: {linear_module.out_features}")

gm = torch.fx.symbolic_trace(MyModule())
gm.graph.__class__ = NodeInfoGraph

print(gm.code)

解説

  1. NodeInfoGraph クラスで on_generate_code() をオーバーライドしています。
  2. グラフ内の各ノードをイテレートし、op'call_module' であり、ターゲットのサブモジュールが nn.Linear のインスタンスであるかどうかをチェックしています。
  3. 該当するノードが見つかった場合、その名前と入力/出力の特徴数をコメントとして生成されるコードに追加しています。
  4. self.get_submodule(node.target) を使用して、ノードが参照している実際の nn.Module のインスタンスを取得し、その属性にアクセスしています。

例3: カスタムオペレータの基本的なサポート

FX がデフォルトでサポートしていないカスタムオペレータがある場合、on_generate_code() を使用してそのコード生成を制御できます(より複雑な場合は、register_node_handler の使用を検討します)。ここでは、単純な例として、my_custom_op という名前のオペレータに対するプレースホルダー的なコードを生成します。

import torch
import torch.nn as nn
from torch.fx import Graph, GraphModule

def my_custom_op(x, y):
    # 実際の処理はここでは定義しない
    return x + y

class ModuleWithCustomOp(nn.Module):
    def forward(self, a, b):
        return my_custom_op(a, b)

class CustomOpGraph(Graph):
    def on_generate_code(self):
        for node in self.nodes:
            if node.op == 'call_function' and node.target == my_custom_op:
                self.code.append(f"{node.name} = my_custom_op({', '.join(str(arg) for arg in node.args)})")
            # デフォルトのコード生成も行う必要がある場合は、super().on_generate_code() を呼び出すことを検討

gm = torch.fx.symbolic_trace(ModuleWithCustomOp())
gm.graph.__class__ = CustomOpGraph

print(gm.code)
  1. my_custom_op という単純な関数を定義します(実際には FX はこの関数の内部を知りません)。
  2. ModuleWithCustomOpmy_custom_opforward メソッド内で呼び出します。
  3. CustomOpGraphon_generate_code() では、op'call_function' であり、targetmy_custom_op であるノードを探します。
  4. 該当するノードが見つかった場合、そのノードの名前、オペレータ名、引数を使ってPythonコードの文字列を生成し、self.code に追加します。
  5. 注意
    この例は非常に基本的なものであり、実際のカスタムオペレータの統合には、引数の処理、戻り値の型、エラー処理など、より複雑なロジックが必要になる場合があります。より本格的なカスタムオペレータのサポートには、torch.fx.passes.graph_transformtorch.fx.node.register_node_handler の使用が推奨されます。


torch.fx.passes.graph_transform を使用したグラフ変換

torch.fx.passes.graph_transform は、グラフの中間表現(torch.fx.Graph)に対して変換を行うためのメカニズムです。on_generate_code() がコード生成直前のフックであるのに対し、グラフ変換はより早い段階でグラフのノードの追加、削除、変更などを行うことができます。


  • 特定のパターンを検出し、融合したノードに置き換える、不要なノードを削除する、カスタムオペレータを標準的なオペレータの組み合わせに分解する、など。
  • 利点
    • コード生成前にグラフ自体を操作するため、生成されるコードの構造や内容を根本的に変更できます。
    • 複数の変換をパイプラインとして適用できるため、複雑な処理を段階的に行うことができます。
    • グラフの最適化や特定パターンの置換など、より広範な目的に利用できます。
import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace
from torch.fx.passes.graph_transform import GraphModuleTransformation

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 20)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

class FuseLinearReLU(GraphModuleTransformation):
    def pattern(self):
        class Pattern(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = None
                self.relu = None

            def forward(self, x):
                return self.relu(self.linear(x))
        return Pattern()

    def replacement(self, pattern_node):
        linear_node = pattern_node.args[0]
        relu_node = pattern_node

        # 新しい融合されたオペレータ(ここでは例として文字列)
        fused_node = self.graph.create_node(
            op='call_function',
            target='fused_linear_relu',
            args=(linear_node.args[0], linear_node.kwargs, relu_node.args[0]), # 引数の処理はより複雑になる場合あり
            kwargs={}
        )
        return fused_node

m = MyModule()
gm = symbolic_trace(m)
fuse_pass = FuseLinearReLU()
fused_gm = fuse_pass(gm)

print("Original Graph:")
print(gm.graph)
print("\nFused Graph:")
print(fused_gm.graph)

torch.fx.node.register_node_handler を使用したカスタムノードの処理

グラフ内に FX がデフォルトでサポートしていないカスタムオペレータや関数が含まれる場合、register_node_handler を使用して、それらのノードがどのようにコード生成されるかを定義できます。


  • 自作のC++拡張で定義されたオペレータに対するPythonバインディングの呼び出しコードを生成する、特定の外部ライブラリの関数呼び出しを生成するなど。
  • 利点
    • 特定の種類のノードに対して、生成されるPythonコードを細かく制御できます。
    • グラフ変換よりも、ノードの種類に基づいたより直接的なコード生成のカスタマイズが可能です。
import torch
import torch.nn as nn
from torch.fx import GraphModule, symbolic_trace
from torch.fx.node import Node, register_node_handler

def my_custom_function(a, b):
    return a * b

@register_node_handler(my_custom_function)
def my_custom_function_handler(node: Node, graph_module: GraphModule):
    return f"{node.name} = my_custom_function({', '.join(str(arg) for arg in node.args)})"

class MyModuleWithCustomFunc(nn.Module):
    def forward(self, x, y):
        return my_custom_function(x, y)

m = MyModuleWithCustomFunc()
gm = symbolic_trace(m)

print(gm.code)

torch.fx.interpreter.Interpreter のサブクラス化

torch.fx.Interpreter は、torch.fx.Graph を解釈して実行するための基盤となるクラスです。これをサブクラス化することで、グラフの実行方法をカスタマイズできます。コード生成とは少し異なりますが、グラフの構造に基づいて特定の処理を行いたい場合に利用できます。


  • グラフの実行中に特定の統計情報を収集する、特定のノードの出力を変更する、カスタムハードウェア上での実行をシミュレートするなど。
  • 利点
    • グラフの実行ロジックを直接制御できます。
    • 中間的な値を検査したり、特定のノードの実行をフックしたりできます。

テンプレートエンジンやコード生成ライブラリの利用

より複雑なコード生成のニーズがある場合、Jinja2 などのテンプレートエンジンや、ast モジュールなどのPythonのコード生成ライブラリを on_generate_code() 内で使用することも考えられます。


  • 特定のハードウェアアーキテクチャに最適化されたコードを生成する、複数のバックエンドに対応したコードを生成するなど。
  • 利点
    • 複雑なコード構造をより柔軟かつ可読性の高い方法で生成できます。
    • コードの再利用性や保守性が向上する可能性があります。
  • テンプレートエンジン/コード生成ライブラリ
    複雑なコード構造を生成する場合や、コードの可読性・保守性を高めたい場合に、on_generate_code() 内で利用することを検討できます。
  • インタプリタのサブクラス化
    グラフの実行方法をカスタマイズしたい場合に適しています(コード生成とはやや異なる目的)。
  • ノードハンドラ (register_node_handler)
    FX がデフォルトで扱えないカスタムオペレータや関数に対して、そのコード生成ロジックを定義するのに適しています。
  • グラフ変換 (graph_transform)
    グラフの構造自体を変更する必要がある場合、ノードの追加、削除、置換など、より根本的な操作に適しています。
  • on_generate_code()
    生成されるPythonコードの最終的な整形、コメントの追加、ごく簡単なコード断片の挿入など、コード生成の最終段階における調整に適しています。