torch.fx.Graph.set_codegen()

2025-05-31

torch.fx は、PyTorch の nn.Module インスタンスを変換するためのツールキットであり、主に以下の3つの主要なコンポーネントで構成されています。

  1. シンボリックトレーサー (Symbolic Tracer): Python コードをシンボリックに実行し、操作を記録します。
  2. 中間表現 (Intermediate Representation - IR): シンボリックトレース中に記録された操作を格納する Graph オブジェクトです。この Graph に対して様々な変換が適用されます。
  3. Python コード生成 (Python Code Generation): Graph IR から有効な Python コードを生成します。

torch.fx.Graph.set_codegen() の役割は、この3番目のコンポーネント、つまりPython コード生成の動作をカスタマイズすることです。

通常、Graph から GraphModule (これは nn.Module のインスタンスであり、Graph から生成された forward メソッドを持つ) を作成する際、PyTorch はデフォルトのコード生成ロジックを使用します。しかし、set_codegen() を使用することで、このデフォルトの挙動を独自のコード生成器に置き換えることができます。

具体的な説明

set_codegen() メソッドは、torch.fx.codegen.CodeGen クラスのインスタンス、またはそのサブクラスのインスタンスを引数として受け取ります。CodeGen クラスは、Python コードを生成するためのいくつかのカスタマイズ可能なフックを提供します。

例えば、CodeGen クラスには以下のようなメソッドがあります。

  • additional_globals(): 生成されるコードが追加で使用するグローバル変数を提供します。
  • process_outputs(): グラフからの出力がどのように処理されるべきかを定義します。
  • process_inputs(): グラフへの入力がどのように処理されるべきかを定義します。
  • generate_prologue(): 生成される関数の冒頭部分(def forward(...) など)を生成します。

これらのメソッドをオーバーライドすることで、ユーザーは自分の要件に合わせてコード生成のプロセスを細かく制御できます。

set_codegen() を使うメリット

  • ドメイン固有の最適化: 特定のドメインやデバイスに特化した最適化のために、コード生成のロジックを変更することができます。
  • AOTAutogradなどの高度な機能: AOTAutograd のような、PyTorch の高度な機能の一部は、カスタムのコード生成器を利用して、特定の計算グラフの形式に対応したコードを生成します。
  • 柔軟なコード生成: 特定のバックエンド向けに最適化されたコードを生成したり、デバッグのために詳細な情報を出力するコードを生成したりするなど、標準のコード生成では対応できない柔軟なコード生成が可能になります。
import torch
import torch.fx
import torch.nn as nn
from torch.fx.codegen import CodeGen, PythonCode

class CustomCodeGen(CodeGen):
    def generate_prologue(self, free_vars, maybe_return_annotation):
        # カスタムのフォワードメソッドの冒頭部分を生成
        return f"""
def forward(self, *args){maybe_return_annotation}:
    # カスタムコードジェネレータからのコメント
    {', '.join(free_vars)} = args
"""

    def process_inputs(self, args):
        # 入力の処理をカスタマイズ
        return args

    def process_outputs(self, outputs):
        # 出力の処理をカスタマイズ
        return outputs

# 簡単なモジュール
class MyModule(nn.Module):
    def forward(self, x, y):
        return x + y * 2

# モジュールをトレースしてGraphを取得
m = MyModule()
graph = torch.fx.symbolic_trace(m).graph

# カスタムのコード生成器を設定
graph.set_codegen(CustomCodeGen())

# GraphModuleを再構築 (これにより、設定したcodegenが使用される)
# 通常、set_codegen()を呼び出した後、GraphModuleを再構築するか、
# gm.recompile() を呼び出して forward メソッドを再生成する必要があります。
gm = torch.fx.GraphModule(m, graph)

# 生成されたコードを確認
print(gm.code)

上記の例では、CustomCodeGen という独自のコード生成器を定義し、generate_prologue メソッドをオーバーライドして、生成される forward メソッドの冒頭部分にカスタムコメントを追加しています。このように、set_codegen() を使うことで、Graph から実際に生成される Python コードの見た目や動作を詳細に制御できるようになります。





set_codegen() は、PyTorch FX が Graph から nn.Moduleforward メソッドとして実行可能な Python コードを生成するプロセスをカスタマイズするために使用されます。

例1: 最も基本的なカスタム CodeGen (コメントの追加)

この例では、生成される forward メソッドにカスタムのコメントを追加するだけの、非常にシンプルな CodeGen を作成します。

import torch
import torch.nn as nn
import torch.fx
from torch.fx.codegen import CodeGen, PythonCode # PythonCodeはデフォルトのCodeGen

# 1. カスタムのCodeGenクラスを定義
class SimpleCommentCodeGen(CodeGen):
    # generate_prologue: 生成される関数の冒頭部分(def forward(...))をカスタマイズ
    def generate_prologue(self, free_vars, maybe_return_annotation):
        # 親クラス(PythonCode)のgenerate_prologueを呼び出して、基本的なプロローグを取得
        # ここで、free_vars は forward メソッドの引数となる変数名(例: ['x', 'y'])
        # maybe_return_annotation は戻り値の型ヒント(例: -> Tensor)
        original_prologue = super().generate_prologue(free_vars, maybe_return_annotation)
        
        # 取得したプロローグにカスタムコメントを追加
        return original_prologue + "\n    # --- Custom CodeGen: Start of forward method ---"

    # ここでは他のメソッド(process_inputs, process_outputs, additional_globals など)はオーバーライドせず、
    # デフォルトのPythonCodeの動作をそのまま利用します。

# 2. PyTorch モデルを定義
class MySimpleModule(nn.Module):
    def forward(self, x, y):
        # 簡単な演算
        a = x + y
        b = a * 2
        return b - x

# 3. モデルをシンボリックトレースしてGraphを取得
my_module = MySimpleModule()
# verbose=True にすると、トレース中のノードの詳細が表示される
traced_graph = torch.fx.symbolic_trace(my_module) 
graph = traced_graph.graph

# 4. GraphにカスタムのCodeGenを設定
graph.set_codegen(SimpleCommentCodeGen())

# 5. カスタムCodeGenが適用された新しいGraphModuleを作成
# set_codegen() を呼び出しただけでは、既存の traced_graph (GraphModule) の forward メソッドは更新されません。
# そのため、Graph から新しい GraphModule を再構築する必要があります。
# または、既存の GraphModule に対して gm.recompile() を呼び出すこともできます。
gm_custom_codegen = torch.fx.GraphModule(my_module, graph)

# 6. 生成されたコードを確認
print("--- Generated Code with Custom Comment ---")
print(gm_custom_codegen.code)

# 7. 動作確認
input_x = torch.randn(5)
input_y = torch.randn(5)

output_original = my_module(input_x, input_y)
output_custom = gm_custom_codegen(input_x, input_y)

print(f"\nOriginal module output: {output_original}")
print(f"Custom CodeGen module output: {output_custom}")
assert torch.allclose(output_original, output_custom)
print("Outputs match!")

解説:

  • 重要な点: GraphModuleGraph を元に forward メソッドを生成します。set_codegen() を呼び出した後に、その変更を反映させるには GraphModule(original_module, graph) のように再構築するか、既存の GraphModule に対して gm.recompile() を呼び出す必要があります。
  • graph.set_codegen(SimpleCommentCodeGen()) でこのカスタム CodeGenGraph に設定します。
  • generate_prologue メソッドをオーバーライドし、super().generate_prologue(...) でデフォルトのプロローグを取得した後、独自のコメントを追加しています。
  • SimpleCommentCodeGenCodeGen クラスを継承しています。

例2: 独自のグローバル変数を生成コードに含める

この例では、生成されるコード内で使用したいグローバルな定数や関数を additional_globals() メソッドを通じて提供する方法を示します。

import torch
import torch.nn as nn
import torch.fx
from torch.fx.codegen import CodeGen, PythonCode
import math # 生成コードでmath.piを使用したい

# 1. カスタムのCodeGenクラスを定義
class GlobalConstantCodeGen(CodeGen):
    # additional_globals: 生成コードからアクセスできるグローバルなオブジェクトを定義
    def additional_globals(self):
        # デフォルトのグローバル変数に加えて、mathモジュールとカスタム定数を追加
        # ここで定義したキー(文字列)が、生成コード内で変数名として使われます。
        return {
            **super().additional_globals(), # デフォルトのグローバル変数も引き継ぐ
            "math": math,                 # mathモジュール全体をグローバルとして提供
            "CUSTOM_FACTOR": 10.0         # カスタムの定数
        }

# 2. PyTorch モデルを定義
class MyGlobalModule(nn.Module):
    def forward(self, x):
        # math.pi や CUSTOM_FACTOR を直接使った演算はできません。
        # FXはシンボリックトレースなので、Pythonの組み込み関数は追跡できません。
        # しかし、生成されたコードが後でこれらを使うように、CodeGenを工夫することはできます。

        # ここでは、ただxを返すだけのシンプルなモデルにする
        # 実際に生成されたコードがこれらのグローバル変数を利用するようにします
        return x * 2

# 3. モデルをシンボリックトレースしてGraphを取得
my_global_module = MyGlobalModule()
traced_graph_global = torch.fx.symbolic_trace(my_global_module)
graph_global = traced_graph_global.graph

# 4. GraphにカスタムのCodeGenを設定
graph_global.set_codegen(GlobalConstantCodeGen())

# 5. カスタムCodeGenが適用された新しいGraphModuleを作成
gm_global_codegen = torch.fx.GraphModule(my_global_module, graph_global)

# 6. 生成されたコードを確認
print("\n--- Generated Code with Custom Globals ---")
print(gm_global_codegen.code)

# 生成されたコードでは、実際には 'math' や 'CUSTOM_FACTOR' が使われていないかもしれません。
# それらを強制的に使うようにするには、CodeGenの他のメソッド(例えば、Graphのノードを変換する部分)を
# もっと複雑にオーバーライドする必要があります。
# この例は、あくまで global 変数をコード生成器に渡す方法を示しています。

# 動作確認 (コード自体は元のままなので、出力はMyGlobalModuleと同じ)
input_tensor = torch.randn(3)
output_original_global = my_global_module(input_tensor)
output_custom_global = gm_global_codegen(input_tensor)

print(f"\nOriginal module output (global): {output_original_global}")
print(f"Custom CodeGen module output (global): {output_custom_global}")
assert torch.allclose(output_original_global, output_custom_global)
print("Outputs match!")

# 生成されたコード内に 'math' や 'CUSTOM_FACTOR' が含まれていることを確認
# ここで直接コードを修正するわけではないので、あくまでadditional_globalsが使われたことを示す。
# print(gm_global_codegen.code) の出力を見てください。
# 例えば、もしコード内で torch.add や torch.mul が使われる場合、これらの関数は
# additional_globals() を通じて提供される 'torch' オブジェクトから解決されます。

解説:

  • この例では、math モジュールと CUSTOM_FACTOR というカスタム定数を提供しています。
  • super().additional_globals() を呼び出すことで、PyTorch がデフォルトで提供する torchF (torch.nn.functional) などのグローバル変数も引き継いでいます。
  • このメソッドは辞書を返し、その辞書に含まれるキーと値が、生成される forward メソッドのスコープ内で利用可能なグローバル変数として扱われます。
  • GlobalConstantCodeGen では additional_globals() メソッドをオーバーライドしています。

この例はより高度で、Graph 内の特定の操作 (Node) がどのように Python コードに変換されるかを変更します。ここでは、torch.add (加算) を torch.sub (減算) に置き換える極端な例を示します。

import torch
import torch.nn as nn
import torch.fx
from torch.fx.codegen import CodeGen, PythonCode
from torch.fx.node import Node, map_arg # Nodeとmap_argはノード処理に便利

# 1. カスタムのCodeGenクラスを定義
class CustomOperationCodeGen(PythonCode): # PythonCodeを継承すると便利
    def __init__(self, graph: torch.fx.Graph):
        super().__init__()
        self.graph = graph # Graphオブジェクトを保持

    # emit_node: 各ノードに対応するPythonコードの行を生成
    # これはPythonCodeクラスの内部メソッドであり、通常は直接オーバーライドしませんが、
    # ここでは挙動を変更するために特別にオーバーライドします。
    # より一般的な方法は、Nodeの_codegen_spec_ を使うことですが、これはより直接的です。
    def emit_node(self, node: Node) -> str:
        # call_function ノードの場合をチェック
        if node.op == 'call_function':
            if node.target == torch.add:
                # torch.add の代わりに torch.sub を呼び出すコードを生成
                # map_arg はノードの引数を、生成コードの変数名に変換するために使用
                args_str = ', '.join(map(str, map_arg(node.args, lambda n: self.current_proxy_codegen.var_map[n])))
                kwargs_str = ', '.join(f'{k}={map_arg(v, lambda n: self.current_proxy_codegen.var_map[n])}' for k, v in node.kwargs.items())
                
                # 生成されるコード行
                return f"{self.current_proxy_codegen.var_map[node]} = torch.sub({args_str}{', ' if args_str and kwargs_str else ''}{kwargs_str})"
            # その他のcall_functionはデフォルトの動作
            # else:
            #     return super().emit_node(node) # 親クラスのメソッドを呼び出す

        # その他のノードタイプ(call_module, get_attrなど)はデフォルトの動作を使用
        return super().emit_node(node)

# 2. PyTorch モデルを定義
class MyAddModule(nn.Module):
    def forward(self, x, y):
        # x + y を行うモジュール
        return x + y

# 3. モデルをシンボリックトレースしてGraphを取得
my_add_module = MyAddModule()
traced_graph_add = torch.fx.symbolic_trace(my_add_module)
graph_add = traced_graph_add.graph

# 4. GraphにカスタムのCodeGenを設定
# CustomOperationCodeGen は Graph オブジェクトを必要とするので、ここで渡す
graph_add.set_codegen(CustomOperationCodeGen(graph_add))

# 5. カスタムCodeGenが適用された新しいGraphModuleを作成
gm_custom_op = torch.fx.GraphModule(my_add_module, graph_add)

# 6. 生成されたコードを確認
print("\n--- Generated Code with Operation Change (Add -> Sub) ---")
print(gm_custom_op.code)

# 7. 動作確認
input_x_op = torch.tensor([10.0])
input_y_op = torch.tensor([5.0])

# 元のモデルは加算
output_original_op = my_add_module(input_x_op, input_y_op) # 10 + 5 = 15
print(f"\nOriginal module output (10 + 5): {output_original_op}")

# カスタムCodeGen適用モジュールは減算になるはず
output_custom_op = gm_custom_op(input_x_op, input_y_op) # 10 - 5 = 5
print(f"Custom CodeGen module output (10 - 5): {output_custom_op}")

# 出力が異なることを確認 (意図的に異なる結果を出す例)
assert not torch.allclose(output_original_op, output_custom_op)
assert torch.allclose(output_custom_op, torch.tensor([5.0]))
print("Outputs are intentionally different, as expected (Add changed to Sub)!")

解説:

  • self.current_proxy_codegen.var_map[node] は、現在のノードの結果が格納される変数名を取得するために使用されます。
  • map_arg は、ノードの引数 (他のノードや定数) を、生成されるコード内で使用される適切な変数名やリテラルに変換するためのユーティリティ関数です。
  • if node.target == torch.add: で、もしノードが torch.add 関数を呼び出している場合、そのコード生成を torch.sub を呼び出すように変更しています。
  • emit_node メソッドをオーバーライドしています。このメソッドは、Graph の各 Node オブジェクトに対して、対応する Python コードの行を生成する責任を持ちます。
  • この例では PythonCode を継承しています。PythonCodeCodeGen のサブクラスであり、より多くのデフォルト実装を提供してくれるため、カスタマイズが容易になります。

torch.fx.Graph.set_codegen() は、torch.fx を使ってモデルのコンパイルや最適化、特定のバックエンドへのデプロイを行う際に非常に強力なカスタマイズポイントとなります。

  • GraphModule を再構築/再コンパイルする: set_codegen() を呼び出した後、変更を反映させるために GraphModule を再作成するか、gm.recompile() を呼び出す必要があります。
  • メソッドをオーバーライドする:
    • generate_prologue(): forward メソッドの冒頭部分をカスタマイズ。
    • additional_globals(): 生成コードからアクセスできるグローバル変数を追加。
    • emit_node(): 各 Node からのコード生成ロジックを細かく制御 (より高度)。
    • 他にも process_inputs(), process_outputs(), generate_epilogue() などがあります。
  • CodeGen クラスを継承する: 通常は torch.fx.codegen.CodeGen または torch.fx.codegen.PythonCode を継承します。


torch.fx.Graph.set_codegen() は、PyTorch FX の Graph を Python コードに変換する際の、末端のコード生成ロジックを直接制御するための非常に低レベルなAPIです。これは非常に強力ですが、その分、FX の内部実装やコード生成のメカニズムについて深い理解が求められます。

しかし、PyTorch には、この set_codegen() を直接使うよりも、より高レベルで汎用的な代替手段がいくつか存在します。これらの方法は、多くの場合、特定の最適化や変換の目的を達成するために利用され、必ずしも独自のコード生成器を一から書く必要はありません。

torch.fx を用いたグラフ変換 (Graph Transformations)

set_codegen() がコード生成の最終段階をカスタマイズするのに対し、torch.fx の中核的な目的は、グラフ自体を変換することです。つまり、グラフのノードを追加、削除、変更することで、目的の動作を実現します。そして、変更されたグラフはデフォルトのコード生成器によって Python コードに変換されます。

  • torch.fx.Transformer: グラフのノードを変換するための、より構造化された方法を提供します。特定の op タイプ (例: call_function, call_module) に対応するメソッドをオーバーライドすることで、そのノードの処理ロジックをカスタマイズできます。
  • replace_pattern: 特定のサブグラフパターンを別のサブグラフで置換するための便利なAPIです。コード生成器をいじることなく、高レベルなグラフ最適化を実現できます。
    • 用途: 一般的な最適化(例: ReLU を (x > 0) * x に分解する)や、バックエンド固有の複合演算に置き換える場合。
  • 直接的なグラフ操作: torch.fx.Graph オブジェクトは、その nodes リストを直接操作できます。ノードをループ処理し、特定の条件に基づいてノードを置換したり、新しいノードを挿入したり、不要なノードを削除したりすることで、計算グラフの構造を変更できます。
    • : 畳み込み層とバッチ正規化層の融合 (Conv-BN fusion) や、特定の活性化関数を別のものに置き換えるなどの最適化。

set_codegen() との違い: set_codegen() が「最終的にどのようなPythonコードが出力されるか」を制御するのに対し、グラフ変換は「どのような計算グラフが出力されるか」を制御します。多くの場合、グラフ変換の方がより高レベルで、問題解決に適しています。カスタムのコード生成が必要になるのは、FX が生成する標準の Python 表現では対応できない、非常に特殊なケースに限られます。

torch.compile (TorchDynamo, AOTAutograd, Inductor)

PyTorch 2.0 で導入された torch.compile は、PyTorch モデルのパフォーマンスを大幅に向上させるための主要な機能です。これは内部的に TorchDynamo (PythonバイトコードをキャプチャしてFX Graphに変換)、AOTAutograd (autograd の事前にコンパイルされたグラフを生成)、Inductor (高パフォーマンスなコードを生成するバックエンド) などの技術を組み合わせています。

  • カスタムバックエンド: torch.compile は、カスタムのコンパイルバックエンドを登録する機能を提供します。ユーザーは独自のバックエンド関数を定義し、それを torch.compile(model, backend=my_backend) のように渡すことができます。
    • このカスタムバックエンド関数は、torch.fx.GraphModuleexample_inputs を受け取り、最適化された Callable (通常は C++/CUDA カーネル、または JIT でコンパイルされたコード) を返します。
    • このアプローチは、set_codegen() よりもはるかに高レベルであり、ユーザーは特定のハードウェアやランタイムに特化したコンパイラ(例: TensorRT, OpenVINO, TVM など)と統合することができます。
    • コード生成のロジックは、このバックエンド内で抽象化され、ユーザーは直接 CodeGen を書く必要がなくなります。

set_codegen() との違い: torch.compile のカスタムバックエンドは、PyTorch の計算グラフ全体を別の実行形式(通常は低レベルな最適化されたコード)に変換することを目的としています。set_codegen() はあくまで Python コードを生成しますが、torch.compile はPythonのレイヤーをバイパスして、より効率的なネイティブコードやJITコンパイルされたコードを生成します。

TorchScript (JIT コンパイル)

torch.jit モジュールは、PyTorch モデルをシリアライズ可能で最適化されたグラフ形式に変換するための機能を提供します。これは主にモデルのデプロイメント(C++環境などPython以外の環境で実行する場合)や、PyTorch の内部JITコンパイラによる最適化に利用されます。

  • カスタムC++/CUDAオペレータの統合: TorchScript は、カスタムのC++やCUDAで書かれたオペレータをグラフに組み込む機能も提供します。これにより、Pythonでは遅い部分を高速なネイティブコードで実装し、それをモデルに統合できます。
  • torch.jit.trace: 実際の入力例を実行して、その実行パスを記録し、線形なグラフ(トレース)を構築します。制御フローは失われますが、シンプルなモデルのキャプチャには適しています。
  • torch.jit.script: Pythonコードを直接解析し、TorchScript のサブセットに変換します。これにより、制御フロー(if文、forループなど)を保持したままグラフを構築できます。

set_codegen() との違い: set_codegen() が Python のコードを生成するのに対し、TorchScript は C++ で実行可能な中間表現(IR)を生成します。これにより、Python インタープリタのオーバーヘッドなしにモデルを実行でき、さらにTorchScriptコンパイラによる最適化(演算融合など)が適用されます。set_codegen() はPythonのレイヤーに留まるため、TorchScriptのようなC++レベルでの最適化やデプロイメントの柔軟性はありません。

torch.func (functorch)

torch.func (以前は functorch) は、JAX のような関数変換 (vmap, grad, jvp など) を PyTorch で可能にするためのライブラリです。これは直接的なコード生成の代替というよりは、高度なテンソル操作や最適化を可能にするためのツールですが、その一部の機能は FX と密接に関連しています。

  • AOTAutograd: torch.func の内部にある AOTAutograd は、順伝播と逆伝播の計算グラフを事前に抽出・コンパイルすることで、自動微分(Autograd)のオーバーヘッドを削減し、パフォーマンスを向上させます。これは、torch.compile の重要な構成要素の一つでもあります。

set_codegen() との違い: torch.func は、主に PyTorch の自動微分やバッチ処理のメカニズムに焦点を当てています。set_codegen() が最終的なコードの「見た目」や「構造」をカスタマイズするのに対し、torch.func は計算の「実行方法」をより効率的にするための関数変換を提供します。

  • vmapgrad などの関数変換を使って、複雑なテンソル操作や微分計算を効率的に行いたい: torch.func を利用します。
  • Pythonに依存しない形でモデルを保存・デプロイしたい、あるいはC++で実行したい: TorchScript (JITコンパイル) が選択肢となります。
  • モデル全体のパフォーマンスを最大化し、特定のバックエンド(GPU、TPUなど)にデプロイしたい: torch.compile を利用するのが最も強力で現代的なアプローチです。カスタムバックエンドが必要な場合は、torch.compile のバックエンドインターフェースを使用します。
  • FX の Graph を直接操作して、特定のノード置換や融合を行いたい: torch.fx のグラフ変換(replace_pattern、直接的なノード操作、Transformer)が最も適しています。