PyTorch開発者必見!torch.fx.Graph.process_inputs()のすべて

2025-05-31

背景

PyTorch FX は、PyTorch モデルを中間表現であるグラフとして捉え、変換や分析を行うためのツールキットです。このグラフは、モデル内の各演算やパラメータをノードとして表現し、それらの間のデータフローをエッジで表現します。

モデルを FX グラフに変換する際、モデルの forward メソッドに渡される入力は、最初は具体的なテンソルや値として存在します。しかし、FX グラフ内では、これらの入力は placeholder という種類のノードで表現されます。placeholder ノードは、グラフへの入力点を表し、具体的な値は後で供給されることを想定しています。

process_inputs() の役割

process_inputs() メソッドは、以下の主要な処理を行います。

  1. 入力の追跡
    forward メソッドに渡される引数を順番に追跡します。
  2. placeholder ノードとの関連付け
    各入力引数に対応する placeholder ノードをグラフ内で見つけます。通常、グラフ変換の初期段階で、forward メソッドの引数に対応する placeholder ノードが生成されています。
  3. メタデータの付与
    最も重要な点として、process_inputs() は、実際の入力引数から得られる情報(例えば、テンソルの形状、データ型、デバイスなど)を、対応する placeholder ノードのメタデータとして記録します。このメタデータは、後続のグラフ変換や最適化の段階で、各ノードの出力形状やデータ型を推論するために非常に重要になります。

具体例

簡単な例で考えてみましょう。

import torch
import torch.nn as nn
import torch.fx

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x, y):
        out = self.linear(x) + y
        return out

model = MyModule()
graph = torch.fx.Graph()
tracer = torch.fx.Tracer()
traced_graph = tracer.trace(model)

# traced_graph は FX グラフです。
# この時点では、入力 'x' と 'y' に対応する placeholder ノードはありますが、
# 具体的な形状やデータ型はまだ付与されていない可能性があります。

# 具体的な入力を作成
input_tensor = torch.randn(1, 10)
other_tensor = torch.randn(1, 5)

# グラフの forward メソッドのシグネチャを取得
forward_signature = traced_graph.forward.__func__

# 入力引数を名前と値のペアとして準備
inputs = dict(zip(forward_signature.__code__.co_varnames[1:], (input_tensor, other_tensor)))

# process_inputs() を呼び出すことで、placeholder ノードにメタデータが付与される
traced_graph.process_inputs(inputs)

# traced_graph 内の placeholder ノードを確認すると、
# 'x' と 'y' に対応するノードに形状やデータ型などのメタデータが付与されていることがわかります。
for node in traced_graph.nodes:
    if node.op == 'placeholder':
        print(f"Node name: {node.name}, Meta: {node.meta}")

この例では、process_inputs() に具体的な入力テンソル (input_tensor, other_tensor) を渡すことで、グラフ内の 'x''y' という名前の placeholder ノードに、それぞれのテンソルの形状やデータ型といった情報がメタデータとして記録されます。

process_inputs() の重要性

process_inputs() は、FX グラフを効果的に活用するために非常に重要なステップです。入力に関する正確なメタデータがなければ、後続のグラフ変換や最適化(例えば、テンソル形状に基づいた演算の融合など)が正しく行えません。



入力引数の不一致 (Mismatch in Input Arguments)

  • トラブルシューティング
    • forward メソッドの引数名を正確に確認してください。
    • process_inputs() に渡す辞書のキーは、forward メソッドの引数名と完全に一致している必要があります。
    • 引数の順序は、辞書を使用する場合は重要ではありません(名前でマッチングされます)。ただし、リストやタプルで渡す場合は、forward メソッドの引数の順序と一致させる必要があります(非推奨)。
    • torch.fx.Graph.forward.__func__.__code__.co_varnames を使用して、forward メソッドの引数名を取得できます(最初の self は除く)。

  • class MyModule(nn.Module):
        def forward(self, a, b):
            return a + b
    
    model = MyModule()
    graph = torch.fx.Graph()
    tracer = torch.fx.Tracer()
    traced_graph = tracer.trace(model)
    
    input_a = torch.randn(2, 2)
    input_b = torch.randn(2, 2)
    
    # エラー: キーの名前が forward メソッドの引数名と異なる
    inputs_wrong_key = {"x": input_a, "y": input_b}
    try:
        traced_graph.process_inputs(inputs_wrong_key)
    except Exception as e:
        print(f"エラー: {e}")
    
    # エラー: 値の順序が forward メソッドの引数の順序と異なる (名前ベースのマッチングが行われるため、通常はエラーにならないことが多いですが、意図しない結果になる可能性あり)
    inputs_wrong_order = {"b": input_b, "a": input_a}
    traced_graph.process_inputs(inputs_wrong_order)
    print(traced_graph) # 'a' に input_b のメタデータ、'b' に input_a のメタデータが付与される
    
  • エラー
    forward メソッドの引数の名前や順序と、process_inputs() に渡す辞書のキーや値の順序が一致しない場合に発生します。

必要な情報の欠落 (Missing Necessary Information)

  • トラブルシューティング
    • process_inputs() には、具体的なテンソル(または必要な属性を持つオブジェクト)を渡すようにしてください。
    • 特に形状やデータ型が重要な場合は、適切な形状とデータ型を持つテンソルを入力として使用してください。

  • class MyModule(nn.Module):
        def forward(self, x):
            return x.mean()
    
    model = MyModule()
    graph = torch.fx.Graph()
    tracer = torch.fx.Tracer()
    traced_graph = tracer.trace(model)
    
    # スカラー値を渡すと、形状などのメタデータが不明確になる可能性
    scalar_input = torch.tensor(1.0)
    inputs_scalar = {"x": scalar_input}
    traced_graph.process_inputs(inputs_scalar)
    print(traced_graph.nodes['x'].meta) # shape が空のリストになる可能性
    
    # 後続の処理で形状に依存する操作があるとエラーになる可能性
    # 例: traced_graph.graph_module(model) を実行しようとした場合など
    
  • エラー
    process_inputs() に渡す値が、グラフのノードが期待するメタデータを推論するのに十分な情報を持っていない場合に、後続の処理でエラーが発生する可能性があります。

placeholder ノードが存在しない (Placeholder Nodes Do Not Exist)

  • トラブルシューティング
    • torch.fx.Tracer を使用してモデルをトレースし、正しくグラフを生成していることを確認してください。
    • 手動でグラフを構築する場合は、forward メソッドの引数に対応する placeholder ノードを適切な名前で作成していることを確認してください。

  • import torch.fx
    
    # 空のグラフを作成 (通常はトレーサーで生成)
    graph = torch.fx.Graph()
    
    # placeholder ノードを手動で追加 (通常はトレーサーが行う)
    x_node = graph.placeholder(name='input_x')
    
    # process_inputs を呼び出すが、キーが placeholder の名前と一致しない
    try:
        graph.process_inputs({"y": torch.randn(2, 2)})
    except Exception as e:
        print(f"エラー: {e}")
    
  • エラー
    process_inputs() を呼び出す前に、対応する placeholder ノードがグラフ内に存在しない場合、エラーが発生します。これは通常、グラフのトレースが正しく行われていない場合に起こります。

型アノテーションの不一致 (Mismatch in Type Annotations)

  • トラブルシューティング
    • forward メソッドの型アノテーションが、実際に渡す入力の型と矛盾しないようにしてください。
    • 型アノテーションが意図しない制約になっている場合は、見直すことも検討してください。
  • エラー
    モデルの forward メソッドに型アノテーションがある場合、process_inputs() に渡す入力の型がアノテーションと大きく異なる場合に、後続の処理で問題が発生する可能性があります。FX は型アノテーションをヒントとして利用することがあります。

カスタムオブジェクトの処理 (Handling Custom Objects)

  • トラブルシューティング
    • カスタムオブジェクトが後続の処理で必要な属性を持っていることを確認してください。
    • 必要に応じて、カスタムノードを作成してカスタムオブジェクトの処理を明示的に定義することを検討してください。
  • 考慮事項
    forward メソッドがテンソル以外のカスタムオブジェクトを入力として受け取る場合、process_inputs() はそれらのオブジェクトをそのまま placeholder ノードのメタデータとして記録します。後続のグラフ変換や最適化でこれらのカスタムオブジェクトの属性にアクセスしようとする場合、それらの属性が存在することを確認する必要があります。
  • PyTorch FX のドキュメントを参照する
    PyTorch の公式ドキュメントは、FX の詳細な情報や使用例を提供しています。
  • ステップごとに確認する
    複雑な処理を行う場合は、各ステップの出力を確認しながら進めることで、問題の箇所を特定しやすくなります。
  • グラフの内容を確認する
    print(traced_graph) などを使用して、生成された FX グラフのノードやメタデータの内容を確認し、期待通りになっているかを検証してください。
  • エラーメッセージをよく読む
    エラーメッセージは、問題の原因を特定するための重要な情報を提供してくれます。


例1: 基本的な使用法

この例では、簡単な PyTorch モジュールをトレースし、process_inputs() を用いて入力テンソルのメタデータをグラフに付与します。

import torch
import torch.nn as nn
import torch.fx

# 簡単なモジュールを定義
class SimpleModule(nn.Module):
    def forward(self, x):
        return x + 1

# モジュールのインスタンスを作成
model = SimpleModule()

# FX Tracer を作成
tracer = torch.fx.Tracer()

# モデルをトレースしてグラフを取得
graph = tracer.trace(model)

# forward メソッドの引数名を取得
forward_arg_name = list(graph.signature.parameters.keys())[1] # 'self' を除く最初の引数

# 具体的な入力テンソルを作成
input_tensor = torch.randn(2, 3, dtype=torch.float32)

# 入力を辞書形式で準備 (キーは forward メソッドの引数名)
inputs = {forward_arg_name: input_tensor}

# process_inputs() を呼び出して入力メタデータを処理
graph.process_inputs(inputs)

# グラフ内の placeholder ノードを確認
for node in graph.nodes:
    if node.op == 'placeholder':
        print(f"ノード名: {node.name}")
        print(f"メタデータ: {node.meta}")

# グラフを表示
print("\n生成されたグラフ:")
print(graph)

解説

  1. SimpleModule は、入力テンソルに 1 を加えるだけの簡単なモジュールです。
  2. torch.fx.Tracer() を使って model をトレースし、FX グラフ (graph) を取得します。
  3. graph.signature.parameters.keys() を使って forward メソッドの引数名を取得します(ここでは 'x')。
  4. 具体的な入力テンソル input_tensor を作成します。
  5. process_inputs() には、入力の名前 ('x') をキーとし、実際の入力テンソル (input_tensor) を値とする辞書を渡します。
  6. グラフ内の placeholder ノード(ここでは 'x')の meta 属性を確認すると、shapedtypestridedevice などの情報が記録されていることがわかります。これは、process_inputs() が渡された入力テンソルからこれらの情報を抽出し、placeholder ノードに付与したためです。
  7. 最後に、生成されたグラフ全体を表示します。placeholder ノードが入力点として存在し、後続の演算 (add_scalar) に繋がっていることがわかります。

例2: 複数の入力を持つ場合

この例では、複数の入力を持つモジュールをトレースし、それぞれの入力に対して process_inputs() を適用します。

import torch
import torch.nn as nn
import torch.fx

class MultiInputModule(nn.Module):
    def forward(self, a, b):
        return a + b

model = MultiInputModule()
tracer = torch.fx.Tracer()
graph = tracer.trace(model)

# forward メソッドの引数名を取得
forward_arg_names = list(graph.signature.parameters.keys())[1:] # 'self' を除くすべての引数

# 具体的な入力テンソルを作成
input_a = torch.randn(2, 2, dtype=torch.float64)
input_b = torch.randn(2, 2, dtype=torch.float32)

# 入力を辞書形式で準備
inputs = {forward_arg_names[0]: input_a, forward_arg_names[1]: input_b}

# process_inputs() を呼び出す
graph.process_inputs(inputs)

# グラフ内の placeholder ノードを確認
for node in graph.nodes:
    if node.op == 'placeholder':
        print(f"ノード名: {node.name}")
        print(f"メタデータ: {node.meta}")

print("\n生成されたグラフ:")
print(graph)

解説

  1. MultiInputModule は、2つの入力 ab を受け取り、それらを加算するモジュールです。
  2. forward_arg_names には ['a', 'b'] が格納されます。
  3. inputs 辞書では、キーとして引数名 ('a', 'b') を、値として対応する入力テンソル (input_a, input_b) を指定しています。
  4. process_inputs() を呼び出すことで、それぞれの placeholder ノード ('a', 'b') に、対応する入力テンソルのメタデータ(異なる dtype を持つ点に注目)が付与されます。

例3: テンソル以外の入力を持つ場合

process_inputs() は、テンソルだけでなく、Python の基本的な型(数値、文字列、ブール値など)も扱うことができます。ただし、その場合、メタデータとして記録される情報は限られます。

import torch
import torch.nn as nn
import torch.fx

class ModuleWithNonTensorInput(nn.Module):
    def forward(self, x, factor):
        return x * factor

model = ModuleWithNonTensorInput()
tracer = torch.fx.Tracer()
graph = tracer.trace(model)

forward_arg_names = list(graph.signature.parameters.keys())[1:]

input_tensor = torch.randn(2, 2)
input_factor = 2.0

inputs = {forward_arg_names[0]: input_tensor, forward_arg_names[1]: input_factor}

graph.process_inputs(inputs)

for node in graph.nodes:
    if node.op == 'placeholder':
        print(f"ノード名: {node.name}")
        print(f"メタデータ: {node.meta}")

print("\n生成されたグラフ:")
print(graph)

解説

  1. ModuleWithNonTensorInput は、テンソル x と数値 factor を受け取ります。
  2. process_inputs() には、テンソルと数値の両方を含む辞書を渡します。
  3. placeholder ノード 'x' のメタデータにはテンソルの情報が含まれますが、'factor' のメタデータは、その型 (float) 程度の情報になります。
  • process_inputs() を呼び出す前に、torch.fx.Tracer().trace(model) などを用いてグラフを生成しておく必要があります。
  • process_inputs() は、渡された入力から得られる情報を placeholder ノードの meta 属性に記録します。このメタデータは、後続のグラフ変換や最適化の際に利用されます。
  • process_inputs() に渡す辞書のキーは、トレースされたグラフの forward メソッドの引数名と一致している必要があります。


手動でメタデータを設定する

process_inputs() を使わずに、placeholder ノードの meta 属性に直接メタデータを書き込むことができます。これは、入力の形状やデータ型が静的にわかっている場合や、process_inputs() では取得できない追加のメタデータを設定したい場合に有効です。

import torch
import torch.nn as nn
import torch.fx

class SimpleModule(nn.Module):
    def forward(self, x):
        return x + 1

model = SimpleModule()
tracer = torch.fx.Tracer()
graph = tracer.trace(model)

# forward メソッドの入力に対応する placeholder ノードを取得
input_node = None
for node in graph.nodes:
    if node.op == 'placeholder' and node.name == 'x':
        input_node = node
        break

if input_node:
    # 手動でメタデータを設定
    input_node.meta['shape'] = (2, 3)
    input_node.meta['dtype'] = torch.float32
    input_node.meta['stride'] = (3, 1)
    input_node.meta['device'] = torch.device('cpu')

# グラフを表示
print(graph)
print(input_node.meta)

解説

  1. モデルをトレースしてグラフを取得するまでは通常通りです。
  2. グラフ内の placeholder ノードを名前 ('x') で検索します。
  3. 見つかった placeholder ノードの meta 属性に、辞書形式で形状 (shape)、データ型 (dtype)、ストライド (stride)、デバイス (device) などの情報を直接設定します。
  4. process_inputs() を呼び出す必要はありません。

注意点

  • 入力の形状やデータ型が動的に変わるモデルには適していません。
  • この方法では、メタデータを手動で正確に設定する必要があります。誤った情報を設定すると、後続のグラフ変換や最適化に悪影響を及ぼす可能性があります。

グラフ変換中にメタデータを推論・伝播させる

process_inputs() で初期入力のメタデータを設定した後、グラフ変換の過程で各ノードの出力形状やデータ型を推論・伝播させる方法があります。PyTorch FX の変換 API を利用して、カスタムの変換ルールを定義し、ノードの演算に基づいてメタデータを更新できます。

import torch
import torch.nn as nn
import torch.fx
from torch.fx import symbolic_trace

class SimpleModule(nn.Module):
    def forward(self, x):
        return x.mean(dim=1)

model = SimpleModule()
graph = symbolic_trace(model)

# 入力テンソルを作成
input_tensor = torch.randn(2, 3, dtype=torch.float32)

# process_inputs() で初期メタデータを設定
graph.process_inputs({'x': input_tensor})

# 簡単な変換: 平均操作後の形状を推論する (実際にはより複雑なロジックが必要)
for node in graph.nodes:
    if node.op == 'call_method' and node.name == 'mean':
        input_shape = node.args[0].meta['shape']
        node.meta['shape'] = input_shape[:-1] # 最後の次元を削除
        node.meta['dtype'] = node.args[0].meta['dtype']

# グラフを表示し、各ノードのメタデータを確認
print(graph)
for node in graph.nodes:
    print(f"ノード名: {node.name}, メタデータ: {node.meta.get('shape')}, {node.meta.get('dtype')}")

解説

  1. symbolic_trace を使用してグラフを取得します。
  2. process_inputs() で初期入力 x のメタデータを設定します。
  3. グラフのノードを走査し、mean メソッドの呼び出しを見つけます。
  4. mean 演算の入力ノード (node.args[0]) のメタデータから形状を取得し、mean ノードの出力形状を推論して meta 属性に設定します(ここでは最後の次元を削除)。データ型も入力から継承します。

注意点

  • torch.fx.Transformer などの API を利用して、グラフ変換とメタデータ伝播をより体系的に行うことができます。
  • PyTorch FX は、一部の基本的な演算に対して自動的な形状推論機能を持っていますが、複雑な演算やカスタム演算に対しては手動で推論ルールを記述する必要があります。
  • メタデータの推論ロジックは、各演算の種類に応じて適切に実装する必要があります。

torch.export を利用する (PyTorch 2.0 以降)

PyTorch 2.0 で導入された torch.export API は、モデルを様々な形式(TorchScript、ONNX など)にエクスポートする機能を提供しますが、その過程でモデルの形状やデータ型に関する情報が静的に解析されます。エクスポートされたグラフは、process_inputs() を明示的に呼ばなくても、入力の形状やデータ型に関する情報を持っている場合があります。

import torch
import torch.nn as nn
from torch.export import symbolic_trace

class SimpleModule(nn.Module):
    def forward(self, x):
        return x + 1

model = SimpleModule()

# symbolic_trace を使用してトレース
graph = symbolic_trace(model, example_inputs=(torch.randn(2, 3),))

# グラフの placeholder ノードを確認 (メタデータが含まれている可能性あり)
for node in graph.nodes:
    if node.op == 'placeholder':
        print(f"ノード名: {node.name}, メタデータ: {node.meta}")

# エクスポートされたグラフに対して更なる処理を行う
# (例: 最適化、コード生成など)

解説

  1. torch.export.symbolic_trace を使用してモデルをトレースします。この際、example_inputs として入力のサンプルを与えることで、トレーサーが入力の形状やデータ型を推論する手助けをします。
  2. 生成されたグラフの placeholder ノードの meta 属性には、example_inputs から推論された形状やデータ型が含まれている可能性があります。
  3. エクスポートされたグラフは、そのまま後続の処理(例えば、torch.fx.Optimizer を用いた最適化など)に利用できます。

注意点

  • エクスポートの目的が FX グラフの直接的な操作でない場合は、より高レベルな API として利用できます。
  • torch.export は、モデルの構造や演算によっては完全に静的な解析ができない場合があります。

torch.fx.Graph.process_inputs() は、FX グラフの入力メタデータを設定する基本的な方法ですが、以下の代替手段も存在します。

  • torch.export を利用する
    モデルのエクスポートを主な目的とする場合に、副次的にメタデータが得られることがあります。
  • グラフ変換中にメタデータを推論・伝播させる
    動的な形状変化に対応する場合や、より複雑なメタデータ管理が必要な場合。
  • 手動でメタデータを設定する
    静的な情報や追加情報を設定する場合。