PyTorchモデル最適化の選択肢:FXのcall_module()と代替手法を比較

2025-05-31

Tracer は、PyTorch モデルの forward メソッドを実行(シンボリックトレース)しながら、その中で行われる操作を記録していきます。この記録された操作が、後で Graph という形で表現されます。

Tracer.call_module() の役割

Tracer.call_module() は、Tracer がモデルのトレース中に 別の torch.nn.Module のインスタンスの forward() メソッドが呼び出された ことを検出したときに内部的に使用されるメソッドです。

具体的には、Tracer は以下の3種類の呼び出しを区別して記録します。

  1. call_function: torch.relutorch.add のような単体の関数呼び出し。
  2. call_method: テンソルの .view().add() のような、オブジェクトのメソッド呼び出し。
  3. call_module: torch.nn.Lineartorch.nn.Conv2d のように、torch.nn.Module のインスタンスが呼び出された場合(実際にはその forward() メソッドが呼び出される)。

Tracer.call_module() は、この3番目のケース、つまりモジュールが呼び出された場合に、対応する NodeGraph に追加する役割を担います。

シンボリックトレースの仕組み

Tracer は、実際のデータではなく「Proxy」と呼ばれるシンボリックな値を使ってモデルの forward メソッドを実行します。この Proxy オブジェクトが操作されるたびに、Tracer はその操作を GraphNode として記録します。

例えば、以下のようなモデルがあったとします。

import torch
import torch.nn as nn
import torch.fx as fx

class MySubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub_module = MySubModule()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.sub_module(x) # ここで MySubModule の forward が呼び出される
        x = self.relu(x)
        return x

model = MyModule()
traced_model = fx.symbolic_trace(model)
print(traced_model.graph)

このコードを実行すると、traced_model.graph には以下のような情報が含まれます(簡略化しています):

graph():
    %x : [#users=1] = placeholder[target=x]
    %sub_module_linear : [#users=1] = call_module[target=sub_module.linear](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%sub_module_linear,), kwargs = {})
    return relu

ここで、sub_module.linear の呼び出しは call_module として記録されています。Tracer.call_module() は、この sub_module.linearforward メソッドが呼び出された際に、その操作を GraphNode として追加するために利用されます。

Tracer.call_module()torch.fx.Tracer クラスのメソッドであり、必要に応じてこれをオーバーライドすることで、特定のモジュールのトレース動作をカスタマイズすることができます。例えば、ある特定のモブジュールについては内部にトレースせず、単一のノードとして扱いたい場合などに利用されます。



TraceError: symbolically traced variables cannot be used as inputs to control flow (データ依存の制御フロー)

エラーの原因
torch.fx は、静的なグラフ構造をキャプチャすることを目的としています。モデルの forward メソッド内で、テンソルの値に依存するような動的な制御フロー(if 文、for ループなど)があると、トレースが失敗することがよくあります。call_module がこの問題に直接関係しているわけではありませんが、内部のモジュールがこのような動的制御フローを含んでいる場合に発生します。


import torch
import torch.nn as nn
import torch.fx as fx

class DynamicModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 5)

    def forward(self, x):
        # x の値に依存する if 文
        if x.sum() > 0: # <-- ここが問題
            return self.linear1(x)
        else:
            return self.linear2(self.linear1(x))

model = DynamicModule()
# エラーが発生する可能性が高い
# traced_model = fx.symbolic_trace(model)

トラブルシューティング

  • torch.compile の検討
    PyTorch 2.0 以降の torch.compile は、より複雑な制御フローや動的な形状に対応できる場合があります。FX を直接使用するよりも、torch.compile を介して利用する方が、多くのケースで恩恵を受けられることがあります。

  • Tracer.is_leaf_module() のオーバーライド
    特定のモジュールが内部の制御フローを持つためトレースできない場合、そのモジュールを「リーフモジュール」としてマークすることで、内部をトレースせずに単一の call_module ノードとして扱わせることができます。

    class CustomTracer(fx.Tracer):
        def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
            if isinstance(m, DynamicModule): # DynamicModule をリーフモジュールとして扱う
                return True
            return super().is_leaf_module(m, module_qualified_name)
    
    # traced_model = CustomTracer().trace(model) # CustomTracer を使用してトレース
    

    ただし、この方法だと、リーフモジュールの内部の計算グラフが失われるため、その部分の最適化は行えません。

  • データ依存の制御フローの回避
    可能な限り、テンソルの値に依存する制御フローを避けるようにモデルをリファクタリングします。例えば、torch.where() のようなテンソル操作に置き換えられる場合があります。

モジュールクラス情報の損失 (torch.fx.symbolic_trace() loses module class information)

エラーの原因
torch.fx.symbolic_trace を使用してモデルをトレースすると、元のサブモジュールのクラス情報が失われることがあります。トレースされた GraphModule のサブモジュールは、汎用の torch.nn.Module のインスタンスとして表示されることがあり、これは FX の現在のトレーサーの設計によるものです。FX は計算グラフに焦点を当てており、元のモジュール階層や詳細な型情報を必ずしも保持しません。


import torch
import torch.nn as nn
import torch.fx as fx

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = MyBlock()

    def forward(self, x):
        return self.block(x)

model = MyModel()
gm = fx.symbolic_trace(model)

# 期待される動作: gm.block も MyBlock のインスタンスであること
# 実際: gm.block は torch.nn.Module のインスタンスになっている場合がある
# print(type(model.block)) # <class '__main__.MyBlock'>
# print(type(gm.block))    # <class 'torch.nn.modules.module.Module'> のような出力

トラブルシューティング

  • GraphModule の再構築
    トレース後に GraphModule を操作して、ノードに対応するサブモジュールを手動で追加・置換する際に、元のクラス情報に基づいて新しいモジュールをインスタンス化することも考えられます。
  • FX の設計理解
    これは FX の「意図された」動作の一部であることが多いです。FX は計算グラフに焦点を当てており、元のモジュールの型情報が必要な場合は、別の方法で情報を保持する必要があります。

未サポートの操作や組み込み関数

エラーの原因
torch.fx はすべての Python の操作や PyTorch の関数をトレースできるわけではありません。特に、以下のようなケースで問題が発生しやすいです。

  • 非決定的な操作
    トレースは決定的なグラフを生成するため、乱数生成など非決定的な操作は問題となることがあります。
  • データに依存する arange や ones などのテンソルコンストラクタ
    引数がProxyオブジェクトである場合、トレースできないことがあります。
  • Python の組み込み関数や標準ライブラリのモジュール
    open(), len() (テンソルに適用される場合を除く)、list.append() など。


import torch
import torch.nn as nn
import torch.fx as fx

class ProblematicModule(nn.Module):
    def forward(self, x):
        # テンソルのサイズに依存する arange
        # size = x.size(1) # <-- Proxy オブジェクト
        # return torch.arange(size, dtype=torch.long) # <-- 問題となる可能性

        # または
        # return list(x.shape) # <-- Python組み込み関数の使用
        return x.shape[0] * x.shape[1] # FXはテンソルのshapeアクセスは比較的得意

トラブルシューティング

  • torch.compile
    前述の通り、torch.compile は FX より広範なコードパターンに対応できるため、試してみる価値があります。
  • カスタムTracerでのハンドリング
    未サポートの操作をカスタムの Tracer で処理し、例えばその部分を「リーフ」として扱ったり、トレース中に特定のロジックを注入したりすることが考えられます。
  • FXフレンドリーなコード
    トレース可能な操作やモジュールを使用するようにコードを書き直します。Python の組み込み関数ではなく、可能な限り PyTorch の同等の操作(例: torch.numel の代わりに x.numel()list(x.shape) の代わりに x.shape を直接使うなど)を使用します。

モジュールの状態変更 (in-place operations, buffer/parameter manipulation)

エラーの原因
モデルの forward メソッド内で、モジュール自身が保持するパラメーターやバッファ、または中間テンソルをインプレースで変更するような操作は、FX のトレースと相性が悪い場合があります。また、トレーシング中にモジュール内のパラメーターやバッファが動的に変更されると、グラフの静的な性質が損なわれる可能性があります。

トラブルシューティング

  • 状態の管理
    register_bufferregister_parameter で登録されたバッファやパラメーターの動的な更新がトレース時に問題となる場合、そのロジックを forward から分離するか、トレース後に GraphModule を変更するなどの工夫が必要です。
  • インプレース操作の回避
    可能な限りインプレース操作 (.add_(), .mul_()) を避け、非インプレースな操作 (+, *) を使用します。

ModuleNotFoundError: No module named 'torch.fx'

エラーの原因
これは torch.fx 固有のエラーというよりも、PyTorch のバージョンが古いか、インストールが不完全である可能性が高いです。torch.fx は PyTorch 1.8.0 以降で導入されました。

トラブルシューティング

  • PyTorch のアップグレード
    最新の安定版 PyTorch にアップグレードします。
    pip install torch torchvision torchaudio --upgrade --index-url https://download.pytorch.org/whl/cu118 # (ご自身のCUDAバージョンに合わせて)
    
    または CPU 版:
    pip install torch torchvision torchaudio --upgrade
    
  • PyTorch のバージョン確認
    torch.__version__ で現在インストールされている PyTorch のバージョンを確認します。
  • PyTorch Forums や GitHub Issues
    類似の問題が報告されていないか、PyTorch のフォーラムや GitHub Issues を検索します。多くの場合、同様の課題に直面している開発者がいます。
  • GraphModule.print_tabular() の利用
    トレースが部分的に成功した場合や、生成されたグラフの構造を確認したい場合に、traced_model.graph.print_tabular() を使用すると、ノードの種類 (call_module, call_function など) とその引数が表形式で表示され、デバッグに非常に役立ちます。
  • シンプルな例で試す
    複雑なモデルで問題が発生した場合、問題の箇所を特定するために、その部分だけを切り出したシンプルなモデルを作成してトレースを試します。
  • エラーメッセージをよく読む
    PyTorch FX のエラーメッセージは、問題の箇所や原因について具体的なヒントを提供してくれることが多いです。


基本的なトレースと call_module ノードの確認

まずは、torch.fx.Tracer がどのように call_module ノードを生成するかを確認する基本的な例です。

import torch
import torch.nn as nn
import torch.fx as fx

# 1. サブモジュールを含むシンプルなモデルを定義
class MySubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_sub = nn.Linear(10, 5)

    def forward(self, x):
        print(f"  Inside MySubModule.forward with input shape: {x.shape}")
        return self.linear_sub(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, 3)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.sub_module_instance = MySubModule() # MySubModuleのインスタンス

    def forward(self, x):
        print(f"Inside MyModel.forward with input shape: {x.shape}")
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        # ここで MySubModule の forward が呼び出される
        # FXのTracerは、この呼び出しを call_module ノードとして記録する
        x = self.sub_module_instance(x.view(-1, 10)) # Linear層に合わせるためにreshape
        return x

# 2. モデルのインスタンス化
model = MyModel()
example_input = torch.randn(1, 3, 28, 28) # サンプル入力

# 3. モデルのシンボリックトレース
# fx.symbolic_trace() は内部で fx.Tracer を使用する
print("\n--- Performing symbolic trace ---")
traced_model = fx.symbolic_trace(model)

# 4. 生成されたグラフの確認
print("\n--- Generated Graph ---")
print(traced_model.graph)

# 5. ノードの種類を確認
print("\n--- Analyzing Graph Nodes ---")
for node in traced_model.graph.nodes:
    print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}")
    if node.op == 'call_module':
        # call_module ノードの場合、target はモジュールの完全修飾名
        print(f"  This is a call_module node! Module target: {node.target}")

# 6. トレースされたモデルの実行(動作確認)
print("\n--- Running traced model ---")
output_traced = traced_model(example_input)
print(f"Output from traced model shape: {output_traced.shape}")

# オリジナルモデルの実行と比較(sanity check)
# output_original = model(example_input)
# print(f"Output from original model shape: {output_original.shape}")
# print(f"Outputs match: {torch.allclose(output_traced, output_original)}")

解説

  • print(traced_model.graph) の出力を見ると、call_module ノードがリストされていることがわかります。特に、sub_module_instance_linear のような名前で記録されるノードがそれです。(FXはサブモジュールのさらに内部のモジュールまで展開してトレースします)
  • fx.symbolic_trace(model) を実行すると、Tracer はこの self.sub_module_instance の呼び出しを検出し、対応する call_module ノードを生成します。
  • MyModelforward メソッド内で self.sub_module_instance(x.view(-1, 10)) が呼び出されています。

カスタム Tracer を使用して is_leaf_module() をオーバーライドする例

特定のモジュールを FX のトレース対象から外したい(つまり、そのモジュール全体を一つの call_module ノードとして扱いたい)場合に、Tracer.is_leaf_module() メソッドをオーバーライドします。これは、call_module の挙動に影響を与える典型的な例です。

import torch
import torch.nn as nn
import torch.fx as fx

class ComplexModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(10, 10))
        self.sub_linear = nn.Linear(10, 10)

    def forward(self, x):
        # このモジュールは内部で複雑なロジックや
        # FXがサポートしない操作を含んでいると仮定
        if x.sum() > 0: # データ依存の制御フロー (通常はトレース不可)
            x = self.sub_linear(x) + self.param @ x.T # テンソル乗算など
        else:
            x = self.sub_linear(x) - self.param @ x.T
        return x

class MyModelWithComplexPart(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_in = nn.Linear(10, 10)
        self.complex_part = ComplexModule() # トレースしたくないモジュール
        self.linear_out = nn.Linear(10, 5)

    def forward(self, x):
        x = self.linear_in(x)
        x = self.complex_part(x) # この呼び出しを単一の call_module ノードとして保持したい
        x = self.linear_out(x)
        return x

# 1. モデルのインスタンス化
model = MyModelWithComplexPart()
example_input = torch.randn(1, 10)

# 2. カスタムTracerの定義
class CustomTracer(fx.Tracer):
    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # ComplexModuleのインスタンスであれば、その内部をトレースしない(リーフモジュールとする)
        if isinstance(m, ComplexModule):
            print(f"Treating {module_qualified_name} as a leaf module.")
            return True
        # それ以外のモジュールはデフォルトの挙動に従う
        return super().is_leaf_module(m, module_qualified_name)

# 3. カスタムTracerを使用してモデルをトレース
print("\n--- Performing trace with CustomTracer (ComplexModule as leaf) ---")
tracer = CustomTracer()
traced_model = tracer.trace(model)
graph_module = fx.GraphModule(model, traced_model) # GraphModuleを作成

# 4. 生成されたグラフの確認
print("\n--- Generated Graph with CustomTracer ---")
print(graph_module.graph)

# 5. ノードの種類を確認
print("\n--- Analyzing Graph Nodes with CustomTracer ---")
for node in graph_module.graph.nodes:
    print(f"Node: {node.name}, Op: {node.op}, Target: {node.target}")
    if node.op == 'call_module':
        # complex_part が単一の call_module ノードとして存在することを確認
        if node.target == 'complex_part':
            print(f"  Detected 'complex_part' as a single call_module node.")

# 6. トレースされたモデルの実行
print("\n--- Running traced model with CustomTracer ---")
output_traced = graph_module(example_input)
print(f"Output from traced model shape: {output_traced.shape}")

解説

  • 結果のグラフを見ると、complex_part が一つの call_module ノードとして存在し、その内部の linearparam @ x.T のような操作はグラフには現れていないことが確認できます。
  • is_leaf_module の中で、もし現在のモジュールが ComplexModule のインスタンスであれば True を返します。これにより、FX は ComplexModule の内部に入り込まず、その呼び出しを単一の call_module ノードとして記録します。
  • CustomTracer を定義し、is_leaf_module メソッドをオーバーライドしています。
  • ComplexModule はデータ依存の制御フロー(if x.sum() > 0)を含んでおり、通常の方法ではトレースが難しいか失敗する可能性があります。

call_module ノードは、モデルの変換や最適化の際に非常に重要です。ここでは、call_module ノードを識別し、何らかの変換を行う例を示します。

import torch
import torch.nn as nn
import torch.fx as fx

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

model = SimpleModel()
traced_model = fx.symbolic_trace(model)

# グラフのイテレーションと特定の call_module ノードの識別
print("\n--- Iterating through the graph and identifying specific call_module nodes ---")
for node in traced_model.graph.nodes:
    if node.op == 'call_module':
        # Linear層の後に特定の変換を適用したい場合
        if 'linear' in node.target: # 'linear1' や 'linear2' にマッチ
            print(f"Found a Linear module call: {node.target}. You could apply a transformation here.")
            # 例: このLinear層の後にBatchNorm層を追加する変換を検討する

# ここから先はGraphを実際に変更する例 (Transformations)
# 例: Linear層の直後にカスタムの活性化関数を追加する
# 実際には、この処理は pass を実装する形で行われることが多い

def transform_graph(graph: fx.Graph) -> fx.Graph:
    new_graph = fx.Graph()
    env = {} # 古いノードと新しいノードのマッピング

    # カスタムの活性化関数
    class CustomActivation(nn.Module):
        def forward(self, x):
            return torch.sigmoid(x) * x # Swishのようなもの

    # 新しいモジュールを GraphModule に追加するためのコンテナ
    new_modules = {}

    for node in graph.nodes:
        # 新しいノードのオペランドを env から取得
        new_args = tuple(env[arg] for arg in node.args)
        new_kwargs = {k: env[v] if isinstance(v, fx.Node) else v for k, v in node.kwargs.items()}

        if node.op == 'call_module' and 'linear' in node.target:
            # オリジナルの call_module ノードを新しいグラフに追加
            new_node = new_graph.node_copy(node, new_args, new_kwargs)
            env[node] = new_node

            # その直後にカスタム活性化関数を追加
            activation_name = f"{node.name}_custom_act"
            # GraphModuleに新しいモジュールを追加
            new_modules[activation_name] = CustomActivation()

            with new_graph.insert_after(new_node):
                # 新しい call_module ノードを作成して挿入
                custom_act_node = new_graph.call_module(activation_name, args=(new_node,))
                env[node] = custom_act_node # 後続のノードがこの新しいノードを参照するように更新

        else:
            # その他のノードはそのままコピー
            new_node = new_graph.node_copy(node, new_args, new_kwargs)
            env[node] = new_node

    new_graph.lint() # グラフの整合性チェック
    return new_graph, new_modules

print("\n--- Transforming the graph: Adding CustomActivation after Linear layers ---")
transformed_graph, added_modules = transform_graph(traced_model.graph)

# 新しい GraphModule を作成
# 元のモデルのサブモジュールをコピー
transformed_model = fx.GraphModule(traced_model, transformed_graph)
for name, module in traced_model.named_modules():
    if name != '': # ルートモジュールは除く
        setattr(transformed_model, name, module)
# 追加された新しいモジュールを GraphModule に設定
for name, module in added_modules.items():
    setattr(transformed_model, name, module)

print("\n--- Transformed Graph ---")
print(transformed_model.graph)

# Transformed GraphModule の実行
print("\n--- Running transformed model ---")
output_transformed = transformed_model(example_input)
print(f"Output from transformed model shape: {output_transformed.shape}")
  • 最終的に、変換されたグラフと追加されたモジュールを持つ新しい GraphModule を作成し、実行できることを示しています。
  • new_graph.call_module() は、新しい call_module ノードを生成するために内部で Tracer のロジックを使用するようなものです。
  • その call_module ノードの直後に、新しい CustomActivation モジュールの呼び出し(これも call_module ノードになります)を追加しています。
  • node.op == 'call_module' かつ linearnode.target に含まれる場合、その linear 層の call_module ノードを特定します。
  • transform_graph 関数は、元のグラフを走査し、新しいグラフを構築します。


FX 自体が PyTorch の最新の最適化技術(特に PyTorch 2.0 以降)の基盤となっているため、FX の代替というよりも、FX が解決しようとしている問題に対して異なるアプローチを取る方法、と考えるのが適切です。

torch.compile() (PyTorch 2.0 以降の推奨)

torch.compile() は、PyTorch 2.0 で導入された最も推奨される最適化手法です。これは内部的に TorchDynamoTorchInductor を使用し、Python のバイトコードレベルでグラフをキャプチャし、より高速な実行のために最適化されたカーネルにコンパイルします。

特徴

  • 高速化
    オペレーターフュージョン、メモリ最適化、並列化などにより、大幅な高速化が期待できます。
  • バックエンドの柔軟性
    TorchInductor だけでなく、さまざまなバックエンド(例: ONNX Runtime, TensorRT など)と連携して最適化を実行できます。
  • 動的制御フローへの対応
    torch.fx.symbolic_trace() が苦手とするデータ依存の制御フロー(if文やforループなど)に対しても、適切にグラフを分割("graph break")して対応できます。これにより、より多くのモデルを最適化できます。
  • 自動的なグラフキャプチャ
    ユーザーが明示的にグラフを構築する必要はありません。既存の PyTorch モデルをそのまま torch.compile() でラップするだけです。

torch.compile() の例

import torch
import torch.nn as nn

class MySimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)

model = MySimpleModel()
example_input = torch.randn(1, 10)

# モデルをコンパイル
compiled_model = torch.compile(model)

# コンパイルされたモデルを実行
output = compiled_model(example_input)
print(f"Output from compiled model: {output}")

# 複数回実行することで、コンパイルの恩恵が最大化される
for _ in range(10):
    _ = compiled_model(example_input)

torch.fx.Tracer.call_module() との関連性
torch.compile() は内部的に TorchDynamo を使用し、これが PyTorch のバイトコードを分析して torch.fx グラフを生成します。したがって、call_module ノードは内部的に引き続き生成されますが、ユーザーが直接 Tracercall_module() を操作する必要がなくなります。

TorchScript (torch.jit)

TorchScript は、PyTorch モデルをシリアライズ可能で最適化可能な形式に変換するためのツールです。これは、Python の実行環境なしでモデルを実行したい場合(C++ 環境でのデプロイなど)や、Python の GIL の制約を受けずにマルチスレッドで推論を実行したい場合に特に有用です。

TorchScript には主に2つの変換方法があります。

  • Scripting (torch.jit.script)
    Python のソースコードを直接解析し、TorchScript の中間表現にコンパイルします。これにより、データに依存しない制御フロー(forループやif文など)もキャプチャできます。

    import torch
    import torch.nn as nn
    
    class MyScriptingModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(10, 5)
            self.linear2 = nn.Linear(5, 1)
    
        @torch.jit.script
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.sum() > 0: # Scriptingはこれをキャプチャできる
                x = self.linear1(x)
            else:
                x = self.linear2(x) # こちらのパスが実行されることもある
            return x
    
    model = MyScriptingModel()
    example_input_pos = torch.ones(1, 10)
    example_input_neg = -torch.ones(1, 10)
    
    # モデルをスクリプト化(デコレータを使っているため自動的に行われる)
    # @torch.jit.script を使用しない場合は torch.jit.script(MyScriptingModel()) のようにする
    
    # スクリプト化されたモデルを実行
    output_pos = model(example_input_pos)
    output_neg = model(example_input_neg)
    print(f"Output (pos) from scripted model: {output_pos}")
    print(f"Output (neg) from scripted model: {output_neg}")
    
  • Tracing (torch.jit.trace)
    実際の入力を使ってモデルのフォワードパスを実行し、その過程で実行された操作を記録することで、静的なグラフを構築します。torch.fx.symbolic_trace() と似ていますが、TorchScript は特定の入力に対する実際の実行パスを記録するため、データ依存の制御フローがある場合は注意が必要です。

    import torch
    import torch.nn as nn
    
    class MyTracingModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 5)
            self.bn = nn.BatchNorm1d(5)
    
        def forward(self, x):
            x = self.linear(x)
            # if x.sum() > 0: # Tracingではこのような動的制御フローはキャプチャされない
            #     return self.bn(x)
            return self.bn(x)
    
    model = MyTracingModel()
    example_input = torch.randn(1, 10)
    
    # モデルをトレース
    traced_script_module = torch.jit.trace(model, example_input)
    
    # トレースされたモデルを実行
    output = traced_script_module(example_input)
    print(f"Output from traced script model: {output}")
    
    # 保存とロード
    traced_script_module.save("my_traced_model.pt")
    loaded_script_module = torch.jit.load("my_traced_model.pt")
    

torch.fx.Tracer.call_module() との関連性
TorchScript も、内部的にはモデルのモジュール構造を認識し、それらの呼び出しをグラフの一部として扱います。しかし、torch.fx がグラフ変換の柔軟性をPythonレベルで提供するのに対し、TorchScript はより低レベルなJITコンパイルとデプロイに焦点を当てています。torch.fx は Python のデータフローに密接ですが、TorchScript はより厳密な型システムを持ち、C++ランタイムで実行可能な形式を目指しています。

Eager Mode (通常のPyTorch実行)

PyTorch のデフォルトの実行モードは「Eager Mode」です。これは、コードが書かれた通りに即座に操作が実行されることを意味します。グラフは動的に構築され、必要に応じて破棄されます。

特徴

  • デメリット
    グラフ全体の最適化(オペレーターフュージョンなど)が行われにくく、パフォーマンスが低下する可能性があります。Python のオーバーヘッドも発生します。
  • 動的グラフ
    モデルの構造を動的に変更したり、データに依存する制御フローを自由に記述したりできます。
  • 高い柔軟性
    Python の全機能を活用でき、デバッグが容易です。研究開発やプロトタイピングに最適です。

Eager Mode の例

import torch
import torch.nn as nn

class MyEagerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.dropout = nn.Dropout(0.5) # Eager Modeでは自由に使える
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        # 訓練時と推論時で異なる動作
        if self.training:
            x = self.dropout(x)
        x = torch.relu(x) # functional APIも自由に使える
        return self.linear2(x)

model = MyEagerModel()
example_input = torch.randn(1, 10)

# モデルを実行(Eager Mode)
model.train() # 訓練モード
output_train = model(example_input)
print(f"Output (train) from eager model: {output_train}")

model.eval() # 推論モード
output_eval = model(example_input)
print(f"Output (eval) from eager model: {output_eval}")

torch.fx.Tracer.call_module() との関連性
Eager Mode では、torch.fx.Tracer.call_module() のような「グラフキャプチャ」の概念は直接適用されません。各 nn.Module の呼び出しは、単にその forward メソッドが Python インタープリタによって実行されるだけです。FX や TorchScript は、この Eager Mode の柔軟性を保ちつつ、パフォーマンス最適化のために「グラフ」を抽出しようとする試みと位置づけられます。

非常に高度なケースでは、PyTorch が提供する高レベルのモジュールや関数では不十分な場合があり、より低レベルな操作を直接記述することがあります。これには、以下のような方法が含まれます。

  • C++/CUDA 拡張
    Python から C++ や CUDA カーネルを呼び出すことで、PyTorch の演算では達成できないパフォーマンスを引き出します。
  • torch.autograd.Function のカスタム実装
    フォワードパスとバックワードパスを手動で定義し、特定の演算の挙動を最適化します。

これらの方法は、特定の演算に対して極限のパフォーマンスを求める場合に有用ですが、開発の複雑性は大幅に増加します。

torch.fx.Tracer.call_module() との関連性
これらの低レベルな操作は、通常 torch.fx のシンボリックトレースの対象外となります。FX は PyTorch の高レベルな API (torch.nn.Module, torch.nn.functional など) を対象としており、カスタム C++/CUDA カーネルや torch.autograd.Function の内部ロジックは FX グラフには現れません。

torch.fx.Tracer.call_module()torch.fx を用いたグラフ変換の内部的な要素ですが、PyTorch にはモデルの最適化やデプロイに関する様々なアプローチが存在します。

方法特徴call_module()との関連性ユースケース
torch.compile()PyTorch 2.0 以降の推奨。PythonバイトコードからFXグラフを自動生成し、最適化されたカーネルにコンパイル。動的制御フローに対応。内部でFXグラフを生成するため、call_moduleノードも間接的に利用される。ユーザーは直接操作しない。既存モデルの高速化、幅広いモデルタイプへの適用。
TorchScriptモデルをシリアライズ可能/最適化可能な形式に変換。C++環境での実行やマルチスレッド推論に。TracingとScriptingがある。トレース/スクリプト化されたグラフ内でモジュール呼び出しが表現される。FXとは異なるIRとデプロイの目的。C++でのデプロイ、モバイル/組み込み環境での実行、Python GILからの解放。
Eager ModePyTorchのデフォルト実行モード。記述した通りに即座に実行。高い柔軟性とデバッグの容易さ。グラフキャプチャの概念は直接適用されない。モジュールのforwardメソッドが直接実行される。研究開発、プロトタイピング、柔軟なモデル設計、デバッグ。
低レベルな操作torch.autograd.FunctionやC++/CUDA拡張によるカスタムカーネル実装。FXグラフの対象外。これらの低レベルな操作はFXでは透過的に扱われるか、トレースの境界となる。特定の演算における極限のパフォーマンス最適化、カスタムハードウェアへの対応。