PyTorch FX: python_code() で計算グラフをコード化する基本

2025-05-31

FXは、PyTorchのnn.Moduleインスタンスを変換するためのツールキットであり、主に以下の3つの主要なコンポーネントで構成されています。

  1. Symbolic Tracer (シンボリックトレーサー): PyTorchモデルのforwardメソッドのシンボリック実行を行い、その中で発生する演算を記録します。これにより、モデルの計算の流れをグラフとして表現します。
  2. Intermediate Representation (中間表現 - IR): シンボリックトレーシングによって記録された演算は、torch.fx.Graphオブジェクトとして表現されます。これは、Nodeのリストで構成されており、各Nodeは関数呼び出し、メソッド呼び出し、モジュール呼び出し、入出力などを表します。このIRに対して、モデルの最適化や変換が行われます。
  3. Python Code Generation (Pythonコード生成): 変換されたGraphから、元のモデルと等価なPythonコードを生成します。

この3番目のコンポーネントにおいて、torch.fx.Graph.python_code()が重要な役割を果たします。

torch.fx.Graph.python_code() の機能と目的

torch.fx.Graph.python_code()メソッドは、以下のような機能と目的を持っています。

  • 柔軟性: 生成されたPythonコードは、そのままファイルに保存して実行したり、exec()関数などを使って動的に実行したりすることが可能です。
  • プログラムによるモデル生成: 設定ファイルや他のソースからプログラム的にモデルを生成する際に、Graphを構築し、それをpython_code()で具体的なPythonコードに変換するという使い方も考えられます。
  • Python-to-Python (Module-to-Module) 変換: FXの大きな特徴の一つは、「PythonからPythonへ」の変換を可能にすることです。Graphオブジェクトに対して様々な最適化(演算の融合、量子化など)を適用した後、python_code()を使って変換後の新しいnn.Moduleインスタンスを生成するためのPythonコードを得ることができます。これにより、FXで変換されたモデルは、他の通常のPyTorchモジュールと同様に扱うことができます。
  • グラフの可視化とデバッグ: 抽象的な計算グラフを、人間が読みやすいPythonコードとして出力することで、モデルの動作を理解しやすくします。特に、モデル変換や最適化を行った後に、その変更がどのようにコードに反映されたかを確認するのに役立ちます。

基本的な使用の流れは以下のようになります。

  1. 既存のtorch.nn.Moduletorch.fx.symbolic_traceでトレースし、torch.fx.GraphModule(その中にGraphが含まれる)を取得します。
  2. 必要に応じて、取得したGraphを(GraphModule.graphを通じて)変更します。
  3. 変更後のGraphまたは元のGraphからpython_code()を呼び出して、Pythonコード文字列を取得します。
import torch
import torch.fx
import torch.nn as nn

# 例としてシンプルなPyTorchモデルを定義
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# モデルのインスタンスを作成
model = MyModule()

# モデルをFXでシンボリックトレース
traced_model = torch.fx.symbolic_trace(model)

# トレースされたGraphを取得
graph = traced_model.graph

# GraphからPythonコードを生成
generated_code = graph.python_code()

print("--- 生成されたPythonコード ---")
print(generated_code)

# 生成されたコードを新しいモジュールとして実行することも可能(例)
# from types import FunctionType
# new_forward_fn = FunctionType(compile(generated_code, '<string>', 'exec'), globals())
# traced_model.forward = new_forward_fn
# print(traced_model(torch.randn(1, 10)))

上記のコードを実行すると、MyModuleforwardメソッドの計算グラフに対応するPythonコードが出力されます。このコードは、torch.fx.GraphModuleが内部でどのようにforwardメソッドを生成しているかを示しています。



torch.fx.symbolic_trace 自体の制限によるエラー

python_code() は、あくまで symbolic_trace によって生成された Graph オブジェクトをコードに変換するものです。そのため、そもそものトレースがうまくいかない場合に問題が発生します。

一般的なエラー

  • サポートされていない演算/型:

    • 原因: FXトレーサーは、すべてのPyTorch演算やPythonの型をグラフに変換できるわけではありません。特に、カスタムのPythonオブジェクトや、PyTorchの内部実装に深く関わるような低レベルな操作はトレースできないことがあります。
    • トラブルシューティング:
      • FXのドキュメント確認: 使用しているPyTorchのバージョンでFXがサポートしている演算やパターンを確認します。
      • カスタム演算のラップ: カスタムのPython関数をPyTorchのオペレーションとして登録するなどの方法を検討します。ただし、これは複雑な場合があります。
      • 別の方法の検討: FXによる変換が難しい場合は、TorchScriptなどの別の最適化ツールを検討するか、手動でモデルをリファクタリングすることを検討します。
  • Graph Break (グラフブレイク):

    • 原因: PyTorch FX (および torch.compile のバックエンドとしても使われるDynamo) は、モデルの計算グラフを静的に捉えようとします。しかし、Pythonの動的な機能(データ依存の制御フロー、一部のPython組み込み関数、torch 名前空間外の関数呼び出しなど)は、静的なグラフとして表現することが困難であり、そこで「グラフブレイク」が発生します。グラフブレイクが発生すると、その部分でグラフのトレースが中断され、最適化の機会が失われます。python_code() は、ブレイクした部分をPythonの通常の呼び出しとして出力しますが、期待通りの単一の最適化されたグラフにならない可能性があります。
    • :
      import torch
      import torch.nn as nn
      import torch.fx
      
      class MyModule(nn.Module):
          def forward(self, x):
              if x.sum() > 0: # データ依存の制御フロー
                  return x * 2
              else:
                  return x / 2
      
      model = MyModule()
      # ここで `symbolic_trace` がグラフブレイクを起こす可能性がある
      # もし `torch.compile` を使っているなら、より明確なエラーや警告が出る
      traced_model = torch.fx.symbolic_trace(model)
      graph = traced_model.graph
      print(graph.python_code())
      
    • トラブルシューティング:
      • データ依存の制御フローの回避: 可能な限り、モデルの構造が入力データに依存しないように設計します。例えば、if 文の条件にテンソルの値を直接使うのではなく、定数やテンソルのメタデータ(形状、次元数など)を使うようにします。
      • 対応するPyTorch関数への置き換え: Pythonの組み込み関数やNumPyなどの外部ライブラリの関数を使っている場合、それに対応するPyTorchの関数 (torch.mean, torch.sum など) があれば置き換えます。
      • torch._assert() の使用: assert 文を使いたい場合、torch._assert() はトレース可能なため、これを使用します。
      • torch.cond の利用: PyTorch 2.0以降では、条件分岐をグラフに含めるための torch.cond が導入されています。これにより、特定のデータ依存の制御フローをグラフ内で表現できるようになります。
      • torch.compile の利用: 現在では、FXを直接使うよりも torch.compile を使う方が推奨されます。torch.compile は内部でDynamoを利用しており、グラフブレイクをより賢くハンドリングし、問題箇所を特定するのに役立つ詳細なログや警告を提供します。

生成されたPythonコードに関するエラー

python_code() 自体はコード文字列を生成するだけなので、直接エラーを出すことは少ないですが、生成されたコードが期待通りでない場合や、そのコードを実行しようとしたときに問題が発生することがあります。

一般的なエラー

  • 型ヒントの欠落または不正確さ:

    • 原因: python_code() が生成するコードは、元のモデルが持っていた型ヒントを完全に保持しない場合があります。これは、主にFXの内部的な表現がPyTorchのテンソル操作に焦点を当てているためです。
    • トラブルシューティング:
      • 手動での追加: 生成されたコードを基に、必要に応じて手動で型ヒントを追加します。
      • 型チェッカーでの検証: mypy などの型チェッカーを使用して、型に関する問題を特定します。
  • 生成されたコードの解釈が難しい/バグがある:

    • 原因: FXが生成するコードは、最適化や内部表現の都合上、必ずしも人間にとって読みやすい形式ではありません。特に、元のモデルが複雑だったり、様々な変換が適用されたりすると、コードが複雑になることがあります。また、ごく稀に、FXのバグにより生成されたコードが意味的に正しくないことがあります。
    • トラブルシューティング:
      • FXグラフの検査: print(graph)graph.dump_graph() を使用して、Pythonコードを生成する前のGraphオブジェクトを直接検査します。グラフ構造が期待通りであれば、コード生成の問題である可能性が低くなります。
      • 最小限の再現可能な例: 問題を切り分けるために、できるだけシンプルなモデルで同様の問題が再現するか試します。
      • PyTorchのバージョンアップ: バグである場合、PyTorchの新しいバージョンで修正されている可能性があります。
      • 生成されたコードのデバッグ: 生成されたコードをファイルに保存し、通常のPythonコードとしてステップ実行するなどしてデバッグを試みます。

python_code() はコードを生成するだけで、直接的なパフォーマンス問題を引き起こすことはありませんが、生成されたグラフが最適化されていない(グラフブレイクが多いなど)場合、そのコードを基にしたモデルの実行時にパフォーマンスが低下することがあります。

一般的なエラー

  • 期待されたパフォーマンスが得られない:
    • 原因: 前述のグラフブレイクが多発している場合、FXによる最適化の恩恵を十分に受けられないため、期待するパフォーマンスが出ないことがあります。
    • トラブルシューティング:
      • グラフブレイクの特定と解消: torch.compile の詳細なログ (TORCH_LOGS="dynamo" python your_script.py など) を確認し、どこでグラフブレイクが発生しているかを特定し、対応策を講じます。
      • プロファイリング: torch.profiler などを使用して、モデルのボトルネックを特定し、グラフが生成されたコードが本当に遅いのか、他の部分に原因があるのかを切り分けます。
  • 状態の管理: GraphModule はモデルの状態(パラメータやバッファ)を保持します。python_code() で生成されたコードは、通常、これらの状態へのアクセス方法を含みますが、手動でコードを操作する場合は、状態の適切なロードと保存を考慮する必要があります。
  • 依存関係: 生成されたPythonコードを実行する環境に、元のモデルが依存していたライブラリ(torch.nn.functional など)が適切にインポートされていることを確認する必要があります。python_code() は通常、必要なインポート文を生成しますが、特殊なケースでは追加のインポートが必要になることがあります。


例1: 基本的な使用法 - シンプルなモデルのトレースとコード生成

最も基本的な例です。torch.nn.Module をトレースし、その計算グラフからPythonコードを生成します。

import torch
import torch.nn as nn
import torch.fx

# 1. シンプルなPyTorchモデルの定義
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# 2. モデルのインスタンス化
model = SimpleModel()

# 3. モデルをFXでシンボリックトレースし、GraphModuleを取得
#    GraphModuleはGraphオブジェクトを内部に持っています
traced_model = torch.fx.symbolic_trace(model)

# 4. GraphModuleからGraphオブジェクトを取得
graph = traced_model.graph

# 5. GraphからPythonコードを生成
generated_code = graph.python_code()

print("--- 生成されたPythonコード ---")
print(generated_code)

# 生成されるコードの例(環境によって多少異なる場合があります):
# class GraphModule(torch.nn.Module):
#     def forward(self, x):
#         linear1 = self.linear1(x);  x = None
#         relu = self.relu(linear1);  linear1 = None
#         linear2 = self.linear2(relu);  relu = None
#         return linear2

解説: この例では、SimpleModelという簡単なニューラルネットワークを定義しています。torch.fx.symbolic_trace(model) によってこのモデルの forward メソッドがシンボリック実行され、その計算の流れが torch.fx.Graph オブジェクトとして記録されます。最終的に graph.python_code() を呼び出すことで、このグラフに対応するPythonコードが文字列として出力されます。このコードは、GraphModule というクラスと、その forward メソッドの内部実装を示しています。

例2: 生成されたコードの実行 - 新しいモジュールの作成

python_code() で生成されたコードは単なる文字列ですが、Pythonの exec() 関数などを使って動的に実行し、新しい nn.Module インスタンスを作成することができます。

import torch
import torch.nn as nn
import torch.fx
from types import FunctionType # 関数オブジェクトを動的に作成するために使用

# 例1と同じSimpleModelを再利用
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

model = SimpleModel()
traced_model = torch.fx.symbolic_trace(model)
graph = traced_model.graph
generated_code = graph.python_code()

print("--- 生成されたPythonコード ---")
print(generated_code)

# 1. 生成されたコードを文字列として取得
code_str = generated_code

# 2. 実行するためのグローバル/ローカルスコープを準備
#    torch, nn, fx など、生成されたコードが必要とするモジュールをインポートしておく
exec_globals = {
    'torch': torch,
    'nn': nn,
    'fx': torch.fx,
    'GraphModule': torch.fx.GraphModule # GraphModuleクラス自体も必要
}
exec_locals = {}

# 3. コードを動的に実行
#    これにより、exec_globals['GraphModule'] が新しいモデルの定義で上書きされる
#    (あるいは、別の名前で新しいクラスが定義される)
exec(code_str, exec_globals, exec_locals)

# 4. 新しく定義されたGraphModuleクラスを取得
#    通常、FXは GraphModule という名前でクラスを生成します
NewGraphModuleClass = exec_globals['GraphModule']

# 5. 元のモデルのパラメータとバッファを新しいモジュールにロード
#    traced_model.state_dict() には、元のモデルの重みが含まれています
new_model = NewGraphModuleClass(traced_model) # traced_model はGraphModuleなので、直接渡せる

print("\n--- 新しいモデルのインスタンス化と動作確認 ---")
dummy_input = torch.randn(1, 10)
output_original = model(dummy_input)
output_new = new_model(dummy_input)

print(f"元のモデルの出力:\n{output_original}")
print(f"新しいモデルの出力:\n{output_new}")
# 出力がほぼ同じになることを確認
print(f"出力の差の最大値: {(output_original - output_new).abs().max()}")

解説: この例では、exec() を使って生成されたPythonコードを実行し、新しい GraphModule クラスを定義しています。exec_globals には、生成されたコードが依存するモジュール(torch, nn, torch.fx など)を渡す必要があります。新しい GraphModuleClass をインスタンス化する際に、元の traced_model (これは GraphModule のインスタンス) を渡すことで、元のモデルのパラメータやバッファが新しいモジュールに正しくコピーされます。これにより、元のモデルと全く同じ動作をする新しいモジュールを動的に作成できることがわかります。

例3: Graphを変更した後のコード生成

FXの大きな利点は、Graphオブジェクトを直接操作して、最適化や変換を適用できることです。変更後に python_code() を使って、その変更が反映されたコードを出力できます。

import torch
import torch.nn as nn
import torch.fx

class CustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.param_a = nn.Parameter(torch.tensor(2.0))
        self.param_b = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        # 意図的に非効率な計算(例)
        y = x * self.param_a
        z = y + self.param_b
        return z

model = CustomModule()
traced_model = torch.fx.symbolic_trace(model)
graph = traced_model.graph

print("--- 変更前の生成されたPythonコード ---")
print(graph.python_code())

# Graphの変更例: ノードの置き換え
# ここでは、例えば `add` 操作を別の操作に置き換えるなど(これは例です)
# 実際には、グラフ最適化のパス(Pass)を適用することが多い

# Graphをイテレートして、'add' オペレーションを探し、置き換える例
# ただし、直接ノードを変更するのは高度な操作であり、注意が必要です。
# 通常は、FXのPassesやTransformerなどを使って最適化を行います。

# ここでは非常にシンプルな変更として、最後の`call_function`ノードを
# `call_method`ノードに変換する例を挙げます。
# 通常は、もっと意味のある最適化(融合など)を行います。

# ノードの変更例(非常に単純な概念的な変更)
# 実際には、より複雑なロジックで特定のパターンを見つけ、最適化されたノードに置き換えます。
# 例として、最後の加算操作を乗算に「変更」してみる(意味的な変更ではありません)
# これはあくまで Graph の構造変更の可能性を示すためのものです。
for node in graph.nodes:
    if node.op == 'call_function' and node.target == torch.add:
        # この例では、最終的な出力ノードのオペレーションを強制的に変更する
        # 実際には、より複雑な条件や新しいノードの挿入を行います
        # この直接的な変更は、Graphが有効な状態を保つように慎重に行う必要があります
        pass # 今回は直接的なノード変更は行わず、概念を示すに留めます

# よりFXらしい Graph の変更例:
# 冗長な identity op を削除する、特定のパターンを融合するなど。
# 例として、GraphOptimizerを使って単純な最適化を適用します (FXの標準的な最適化パスではありませんが、概念を示します)
# 通常は、`torch.fx.passes` モジュールやカスタムのPassesを使用します。

# 簡単な最適化の例: ReLUとLinearの融合をシミュレートする(実際には複雑)
# この例では、Graphを直接操作するのではなく、簡潔に概念を示します。

# もし、GraphOptimizerでGraphを変更する処理を実装した場合、
# その変更がGraphオブジェクトに反映され、
# その後で python_code() を呼び出すと、変更後のコードが生成されます。

# 簡略化された例として、ダミーのノードを追加/削除する操作を仮定
# graph.erase_node(some_node)
# graph.insert_node(some_other_node)

# 変更後のGraphからPythonコードを生成
# (ここではGraph自体に大きな変更を加えていませんが、概念を示します)
graph.lint() # 変更後にGraphの整合性をチェックする (推奨)
modified_generated_code = graph.python_code()

print("\n--- 変更後の生成されたPythonコード (変更なしの場合、同じになる) ---")
print(modified_generated_code)

# 実際の Graph の変更例は複雑になるため、ここでは省略します。
# 重要なのは、Graphを変更した後、python_code() を呼び出すと、その変更が反映された
# 新しいPythonコードが得られるということです。

解説: FXの主な目的は、モデルのグラフを変換・最適化することです。この例では、graph オブジェクトを取得した後、もしグラフのノードを追加、削除、変更するような最適化ロジックを適用した場合、その後の graph.python_code() の呼び出しは、変更されたグラフに対応する新しいPythonコードを生成します。これにより、FXで最適化されたモデルを、通常のPyTorchモデルとしてPythonコード形式で配布・実行することが可能になります。

python_code() 自体がエラーを出すことは少ないですが、symbolic_trace の段階でグラフブレイクが発生すると、生成されるコードが期待通りでない場合があります。

import torch
import torch.nn as nn
import torch.fx

# データ依存の制御フローを持つモデル
class ConditionalModel(nn.Module):
    def forward(self, x):
        # グラフブレイクの原因となりうる条件分岐
        if x.mean() > 0:
            return x * 2
        else:
            return x / 2

model = ConditionalModel()

try:
    # グラフブレイクが発生すると、警告が出るか、torch.compile の場合はエラーになる
    # symbolic_trace は単体では警告に留まることが多い
    traced_model = torch.fx.symbolic_trace(model)
    graph = traced_model.graph
    generated_code = graph.python_code()

    print("--- グラフブレイクが発生したかもしれない生成コード ---")
    print(generated_code)
    print("\n注意: 上記コードでは、条件分岐がPythonのif文として残る可能性があります。")
    print("これは、FXがグラフ化できない動的な処理を、Pythonの呼び出しとして残すためです。")

except Exception as e:
    print(f"エラーが発生しました: {e}")
    print("これは、モデルのトレース中に問題があった可能性を示しています。")

解説: ConditionalModelforward メソッドには、入力テンソルの値に依存する if 文があります。このような「データ依存の制御フロー」は、FXが静的な計算グラフとしてトレースすることが難しいため、グラフブレイク(Graph Break)の原因となります。

symbolic_trace は、このような場合でも部分的にグラフを生成しようとしますが、条件分岐のロジック自体はPythonのネイティブな if 文として出力されることが多いです。これにより、FXによるコンパイル時の最適化が、その条件分岐をまたいで適用できなくなります。

トラブルシューティングのヒント:

  • 現在では、torch.compile(model) を使う方が推奨されます。torch.compile は内部でFX (Dynamo を介して) を利用し、グラフブレイクが発生した場合により詳細な情報や警告、またはエラーを提示し、問題の特定と修正を支援します。
  • このようなケースでは、可能であればデータに依存しない制御フローに書き換えるか、torch.cond のようなFX互換の構造を使用します。


主な代替手段

torch.compile() の利用 (最も推奨される現代的な方法)

  • いつ使うか: PyTorchモデルのパフォーマンスを向上させたい場合、真っ先に検討すべき方法です。
  • 利点:
    • 最も簡単で効果的なパフォーマンス最適化。
    • グラフブレイクの自動処理や、より良いデバッグ情報の提供。
    • 多数のバックエンド(inductor, aot_eager など)をサポート。
  • python_code() との関連: torch.compile() は、内部的に生成されたグラフや最適化されたコード(Triton/C++など)をデバッグ目的で出力するオプションを提供することがありますが、それはFXの python_code() が出力するようなPyTorchモジュールのPythonコードとは異なります。torch.compile() の主な目的は「実行時の高速化」であり、「人間が読めるPythonコードの生成」ではありません。
  • 説明: PyTorch 2.0 で導入された torch.compile() は、FXをバックエンドの主要な一つとして利用しており、ユーザーが直接FXのAPI(symbolic_tracepython_code など)を操作することなく、モデルを高速化するための高レベルなインターフェースを提供します。
    • torch.compile() は、内部でモデルの計算グラフを自動的に抽出し(主に torch.fx.Dynamo を使用)、そのグラフに対して様々な最適化(演算融合、メモリ最適化など)を適用し、最終的に生成されたコード(C++/Tritonなど)でモデルを実行します。
    • ユーザーは model = torch.compile(model) と書くだけでよく、python_code() のように明示的にPythonコードを生成して扱う必要がありません。
  • 目的: パフォーマンス最適化。
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()

# torch.compile を適用
compiled_model = torch.compile(model)

# コンパイルされたモデルを実行 (初回実行時にコンパイルが行われる)
input_tensor = torch.randn(1, 10)
output = compiled_model(input_tensor)
print(f"Compiled model output: {output}")

# ここでは、`python_code()` のように明示的なPythonコードは生成されません。
# 内部的に最適化されたコードが使われます。

TorchScript の利用

  • いつ使うか:
    • サーバーサイドのC++アプリケーションでPyTorchモデルを実行したい場合。
    • モバイルデバイスなど、Python環境がない場所でモデルをデプロイしたい場合。
    • モデルを永続的に保存し、Pythonスクリプトなしでロード・実行したい場合。
  • 利点:
    • C++環境での実行(TorchScript JIT)。
    • モデルのシリアライズとデプロイの容易さ。
    • Pythonインタープリタのオーバーヘッドを削減。
  • python_code() との関連: TorchScript は python_code() のようにPythonコードを生成するのではなく、独自のIR(Intermediate Representation)を生成し、それをバイナリ形式で保存します(.pt ファイルなど)。これはPythonから切り離されて実行可能です。
  • 説明: TorchScript は、PyTorchモデルを静的なグラフ表現(IR)に変換するための古い方法です。JITコンパイラを使用して、Pythonの実行を回避し、モデルを独立した形式で保存・ロードできるようにします。
    • torch.jit.trace(): 入力例を与えて動的にトレースします。
    • torch.jit.script(): モデルのソースコードを解析してスクリプト化します。
  • 目的: モデルのシリアライズ、デプロイ、C++環境での実行。
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
input_tensor = torch.randn(1, 10)

# TorchScriptでトレース
scripted_model = torch.jit.trace(model, input_tensor)

# モデルを保存
scripted_model.save("my_scripted_model.pt")

print("TorchScript モデルを 'my_scripted_model.pt' に保存しました。")

# 保存したモデルをロードして実行
loaded_model = torch.jit.load("my_scripted_model.pt")
output = loaded_model(input_tensor)
print(f"Loaded scripted model output: {output}")

# ここでは、FXのpython_code() のようなPythonコードの出力はありません。
# モデルは TorchScript 形式の内部IRとして保存されます。

ONNX へのエクスポート

  • いつ使うか:
    • PyTorch以外の環境でモデルを実行する必要がある場合。
    • ONNX Runtimeなど、特定のONNX互換の高速推論エンジンを利用したい場合。
  • 利点:
    • フレームワーク間のモデル移行。
    • 特定のハードウェア(NVIDIA TensorRT, Intel OpenVINOなど)に最適化されたランタイムの利用。
    • モデルのデプロイメントの柔軟性。
  • python_code() との関連: ONNXもまた、Pythonコードを生成するのではなく、モデルの計算グラフを独自のバイナリ形式で表現します。
  • 説明: ONNX (Open Neural Network Exchange) は、深層学習モデルを表現するためのオープンスタンダードです。PyTorchモデルをONNX形式にエクスポートすることで、TensorFlow, ONNX Runtime, OpenVINO など、ONNXをサポートする他のフレームワークやランタイムでモデルを実行できるようになります。
  • 目的: 異なる深層学習フレームワーク間でのモデルの相互運用性。
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
input_tensor = torch.randn(1, 10)

# ONNX形式でエクスポート
torch.onnx.export(model,               # モデル
                  input_tensor,        # ダミー入力 (グラフ構築用)
                  "my_model.onnx",     # 保存パス
                  export_params=True,  # モデルのパラメータもエクスポート
                  opset_version=14,    # ONNX Opset バージョン
                  do_constant_folding=True, # 定数畳み込み
                  input_names = ['input'],   # 入力ノードの名前
                  output_names = ['output'], # 出力ノードの名前
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 動的バッチサイズ
                                'output' : {0 : 'batch_size'}})

print("ONNX モデルを 'my_model.onnx' に保存しました。")

# このファイルをONNX Runtimeなどでロードして実行できます。
# ここでも、FXのpython_code() のようなPythonコードの出力はありません。

python_code() が「人間が読めるPythonコード」を提供するのに対し、FXの内部的なグラフ表現を直接調べることで、より詳細な情報や、コード生成では得られない構造を確認できます。

torch.fx.Graph オブジェクトの直接検査

  • 利点:
    • python_code() では抽象化されてしまう内部的なグラフ構造を直接確認できる。
    • グラフ変換や最適化のデバッグに役立つ。
    • カスタムのFXパスを開発する際に必須。
  • 説明: python_code() を呼び出す前に、graph オブジェクト自体を直接調べて、そのノード(Node)やエッジ(依存関係)を確認できます。
    • print(graph): 簡潔なグラフのテキスト表現を出力します。
    • for node in graph.nodes:: 各ノードをイテレートして、その操作(node.op)、ターゲット(node.target)、引数(node.args)、キーワード引数(node.kwargs)、出力ノード(node.users)などを詳細に調べることができます。
    • graph.lint(): グラフの整合性をチェックし、潜在的な問題を早期に発見します。
  • 目的: グラフ構造のデバッグ、カスタムパスの作成。
import torch
import torch.nn as nn
import torch.fx

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        return x

model = SimpleModel()
traced_model = torch.fx.symbolic_trace(model)
graph = traced_model.graph

print("--- Graph オブジェクトの簡易表示 ---")
print(graph)

print("\n--- Graph ノードの詳細検査 ---")
for node in graph.nodes:
    print(f"Node: {node.name}")
    print(f"  Op: {node.op}")
    print(f"  Target: {node.target}")
    print(f"  Args: {node.args}")
    print(f"  Kwargs: {node.kwargs}")
    print(f"  Users (誰がこのノードを使うか): {[u.name for u in node.users]}")
    print("-" * 20)

# lint を実行してグラフの整合性を確認
graph.lint()
print("\nGraph lint (整合性チェック) が完了しました。問題はありません。")

torch.fx.GraphModule の利用

  • 利点:
    • python_code() を介してコードを生成し、exec() で実行するよりも直接的で簡単。
    • nn.Module と同じように扱えるため、既存のPyTorchの訓練・推論ループに簡単に組み込める。
    • モデルのパラメータやバッファも自動的に含まれる。
  • 説明: torch.fx.symbolic_trace()GraphModule のインスタンスを返します。この GraphModule は、内部に Graph オブジェクトを持ち、そのグラフを実行可能な nn.Module としてラップしたものです。
    • python_code() は、この GraphModuleforward メソッドに相当するコードを生成します。しかし、ほとんどの場合、GraphModule オブジェクト自体を直接使うことができます。
  • 目的: グラフから実行可能なPyTorchモジュールを直接作成。
import torch
import torch.nn as nn
import torch.fx

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
input_tensor = torch.randn(1, 10)

# traced_model は直接 GraphModule のインスタンス
traced_model = torch.fx.symbolic_trace(model)

print("--- GraphModule の直接実行 ---")
output_traced = traced_model(input_tensor)
output_original = model(input_tensor)

print(f"Original model output: {output_original}")
print(f"Traced model output: {output_traced}")
print(f"Outputs are close: {torch.allclose(output_original, output_traced)}")

# python_code() は、この traced_model の forward メソッドに相当するコードを生成します
# しかし、コードを生成してexec()するよりも、traced_model そのものを使う方が簡単です。
# print(traced_model.graph.python_code())

torch.fx.Graph.python_code() は、FXの内部動作を理解したり、特殊なデバッグシナリオでグラフのPython表現を確認したりするのに非常に役立ちます。しかし、PyTorchモデルのパフォーマンス最適化、デプロイ、異なるフレームワークとの相互運用性といった一般的な目的のためには、以下の代替手段がより適切で推奨されます。

  • FXグラフのデバッグ・カスタムパス開発: Graph オブジェクトの直接検査 (print(graph), node.op など) と GraphModule の直接利用。
  • 異なるフレームワークへのエクスポート: ONNX
  • シリアライズ・デプロイ(C++環境など): TorchScript
  • パフォーマンス最適化: torch.compile()