PyTorch `torch.fx.Node`プログラミング入門:計算グラフの理解と操作

2025-05-31

torch.fxは、PyTorchのプログラムをシンボリックにトレースし、グラフ形式で表現するためのモジュールです。このfxモジュールが構築するグラフの基本的な構成要素が「torch.fx.Node」です。

簡単に言うと、torch.fx.Nodeは、トレースされたPyTorchモデル内の個々の操作(関数呼び出し、メソッド呼び出し、モジュールの利用など)やデータの流れを表すオブジェクトです。

torch.fx.Nodeの主な特徴と役割:

  1. グラフの構成要素: torch.fxはPyTorchモデルの実行をトレースして、計算グラフを構築します。このグラフは一連のNodeオブジェクトで構成され、それぞれがモデル内の特定のステップを表します。

  2. 操作の種類 (OpCode): 各Nodeは、それが表す操作の種類を示す「opcode」属性を持ちます。主なopcodeには以下のようなものがあります。

    • placeholder: モデルへの入力(引数)を表します。
    • call_function: torch.addのようなPythonの関数呼び出しを表します。
    • call_method: テンソルオブジェクトの.relu()のようなメソッド呼び出しを表します。
    • call_module: nn.LinearのようなPyTorchのnn.Moduleの呼び出しを表します。
    • get_attr: モデルの属性(例えば、self.param)の取得を表します。
    • output: モデルの最終的な出力を表します。
  3. 引数とターゲット:

    • args: そのNodeが表す操作に渡される引数(位置引数)のタプルです。これらの引数は、他のNodeオブジェクトへの参照である場合が多いです。これは、計算グラフにおけるデータの依存関係を示します。
    • kwargs: そのNodeが表す操作に渡されるキーワード引数の辞書です。
    • target: そのNodeが実行する具体的な操作(関数、メソッド、モジュールなど)への参照です。例えば、call_functionなら関数オブジェクト、call_moduleならモジュール名(文字列)などが入ります。
  4. 依存関係の表現: Nodeオブジェクトのargskwargsの中に別のNodeオブジェクトが含まれることで、計算グラフにおける入力と出力の依存関係が表現されます。これにより、データの流れや操作の順序を追跡できます。

  5. メタデータ: Nodeは、デバッグや最適化に役立つ追加のメタデータ(例: meta属性)を保持することもできます。

torch.fx.Nodeが使われる場面:

  • 高レベルのコンパイラ構築: PyTorchモデルをターゲットとするコンパイラを構築する際の、中間表現(IR)として利用されます。
  • プログラムの解析: モデルの構造やデータフローを視覚化したり、特定の操作を特定したりするために使用されます。
  • モデルの変換と最適化: fxでモデルをグラフ表現に変換することで、特定のパターンを検出して最適化(例: 演算融合、枝刈りなど)を適用したり、異なるハードウェア向けに変換したりすることが容易になります。

簡単な例:

import torch
import torch.fx

class MyModel(torch.nn.Module):
    def forward(self, x, y):
        a = x + y
        b = a.relu()
        c = b * 2
        return c

# モデルをトレースしてグラフを取得
traced_model = torch.fx.symbolic_trace(MyModel())
graph = traced_model.graph

# グラフ内のノードをイテレート
print("--- Graph Nodes ---")
for node in graph.nodes:
    print(f"Node: {node.name}, OpCode: {node.opcode}, Target: {node.target}, Args: {node.args}, Kwargs: {node.kwargs}")


torch.fxは、Pythonのコードをシンボリックにトレースして計算グラフを構築します。この「トレース」という性質が、エラーの主な原因となることが多いです。

トレースできないPythonの構文/動的な処理

エラーの例:

  • 無限ループ、または予期しない挙動。
  • AssertionError: SymPy expression cannot be converted to a Python value... (特に動的な形状を扱う場合)
  • RuntimeError: torch.fx cannot trace this operation: ...

原因: torch.fx.symbolic_traceは、モデルの実行パスを静的に解析してグラフを構築します。以下のPythonの構文やパターンは、静的なグラフとして表現することが難しいため、トレースに失敗したり、予期しない結果になったりすることがあります。

  • listdictなどのPythonコンテナの動的な変更: torch.Tensor以外のPythonオブジェクトが複雑に操作されると、トレースが難しくなります。特に、torch.catなどにlistを渡す場合、FXはデフォルトでlistではなく単一のTensorが来ると仮定してエラーになることがあります。
  • 可変長の引数(*args, **kwargs)の複雑な使用: forwardメソッドの引数が非常に動的である場合、トレースが困難になることがあります。
  • assert: assert文はトレース中にエラーを引き起こす可能性があります。torch._assert(非公開APIのため注意が必要)に置き換える必要がある場合があります。
  • 組み込み関数や外部ライブラリの利用: torch名前空間外の多くのPython組み込み関数や、NumPyなどの外部ライブラリの関数をモデルのforward内で直接使用すると、トレースに失敗することがあります。
  • インプレース操作(In-place operations): テンソルをインプレースで変更する操作(例: x.add_(y)x[:, 0] = ...)は、FXのトレースでは正しく扱えない場合があります。特にスライスへのインプレース変更はサポートされていません。
  • 動的な制御フロー: if文、forループ、whileループなどが、入力テンソルの値に依存して変化する場合。例えば、バッチサイズによってループ回数が変わる、といったケースです。

トラブルシューティング:

  • エラーメッセージの確認: エラーメッセージには、トレースに失敗した具体的な操作や場所が示されていることが多いので、詳細を確認します。
  • printデバッグ: graph.print_tabular()でグラフの内容を確認したり、node.name, node.opcode, node.target, node.args, node.kwargsなどを出力して、意図したようにノードが生成されているか確認します。
  • カスタムTracerの利用: より複雑なケースでは、torch.fx.Tracerを継承して、特定の操作のトレース方法をカスタマイズする必要があるかもしれません。
  • コードのリファクタリング:
    • 動的な制御フローを避け、静的なグラフに変換できるような構造に変更します。
    • インプレース操作を避け、新しいテンソルを作成する操作(例: x = x + y)を使用します。
    • PyTorchの組み込み関数やtorch.nn.functionalにある関数を使用するように変更します。
    • 複雑なリストや辞書の操作を、テンソル操作で代替できないか検討します。
  • torch.compileの利用を検討する: PyTorch 2.0以降のtorch.compileは、torch.fxのバックエンドとしてTorchDynamoを使用しており、より高度なトレース(動的な制御フローの一部も処理可能)を提供します。torch.fx.symbolic_traceで問題が発生する場合は、まずtorch.compileへの移行を検討してください。

GraphModuleの再コンパイル/実行時のエラー

エラーの例:

  • RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) mismatch (デバイス不一致)
  • AttributeError: 'GraphModule' object has no attribute '...'
  • TypeError: 'NoneType' object is not callable

原因: torch.fx.Nodeを直接操作してグラフを変換した後、GraphModuleを再構築し、それを実行する際に問題が発生することがあります。

  • デバイスの不一致: モデル全体または一部のテンソルが異なるデバイス(CPUとGPU)に存在する場合、実行時にエラーになります。
  • ターゲットの不一致: call_functionノードのtargetが実際に関数オブジェクトでない、call_moduleノードのtargetがモジュール名として正しくない、など。
  • ノードの追加/削除/変更が不適切: グラフの構造を変更した際に、ノード間の依存関係が壊れたり、存在しないノードを参照しようとしたりする。

トラブルシューティング:

  • デバイスの統一: モデルと入力テンソルを同じデバイス(例: model.to('cuda'), input.to('cuda'))に配置するようにします。
  • デバッグ用インタープリタの利用: torch.fx.Interpreterを継承したカスタムインタープリタを作成し、runメソッドをオーバーライドすることで、グラフの実行中に各ノードの入力や出力、属性などを詳細にデバッグできます。
  • ノードの依存関係の確認: node.usersnode.all_input_nodesプロパティを使って、ノード間の依存関係が正しく構築されているかを確認します。特にnode.replace_all_uses_with(new_node)を使用する際は、new_nodeが自分自身を引数として参照してしまわないように注意が必要です(copy.deepcopyを使用するなど)。
  • GraphModuleの確認: 変換後のGraphModuleのコードをprint(gm.code)で出力し、期待通りのPythonコードが生成されているか確認します。また、gm.graph.print_tabular()でノード一覧を再度確認します。

グラフの生成は成功するが、結果が不正確/パフォーマンスが出ない

原因:

  • 最適化の不足: グラフが生成されたとしても、そのグラフが特定の最適化に適した形になっていない場合があります。
  • サイレントな非対応: FXが一部の操作を正しくトレースできず、期待通りのグラフが生成されないが、エラーは発生しない「サイレントな失敗」が発生することがあります。特に、カスタムのPythonオブジェクトや複雑なデータ構造を扱う場合に起こりやすいです。

トラブルシューティング:

  • ベンチマーク: パフォーマンスの目標がある場合、元のモデルとGraphModuleのベンチマークを行い、期待される改善が得られているかを確認します。
  • torch.compilefullgraph=True: torch.compileを使用している場合、fullgraph=Trueを設定すると、モデル全体が単一のグラフとしてコンパイルできない場合にエラーを発生させます。これにより、一部がEager-Modeにフォールバックしてしまっている「サイレントな失敗」を検出できます。
  • グラフの構造の検査: graph.print_tabular()やGraphvizなどのツールを使ってグラフを可視化し、モデルの構造が正しく表現されているか、意図しないノードが存在しないかなどを視覚的に確認します。
  • 元のモデルとの結果比較: トレースされたGraphModuleで推論を実行し、元のモデルの出力と厳密に一致するかを確認します。


基本的な流れ

  1. モデルの定義: torch.nn.Moduleを継承した通常のPyTorchモデルを定義します。
  2. シンボリックトレース: torch.fx.symbolic_trace()を使って、モデルをグラフ表現に変換し、GraphModuleインスタンスを取得します。GraphModuleは、内部にtorch.fx.Graphオブジェクトを持っています。
  3. ノードの走査: graph.nodesをイテレートして、個々のNodeオブジェクトにアクセスします。
  4. ノードのプロパティの確認: 各Nodeopcodetargetargskwargsなどのプロパティを確認します。
  5. ノードの操作(オプション): 必要に応じて、ノードの追加、削除、置き換えなどの操作を行います。
  6. GraphModuleの再コンパイル(操作した場合): グラフに変更を加えた場合は、graph_module.recompile()を呼び出して、変更を反映した新しいforwardメソッドを生成します。
  7. GraphModuleの実行: 変換されたGraphModuleを使って推論を実行します。

例1:モデルのグラフ表現とノードの情報の表示

この例では、シンプルなモデルをトレースし、生成されたグラフ内の各ノードの情報を表示します。

import torch
import torch.nn as nn
import torch.fx
import operator # torch.fx は Python の組み込み演算子もトレースします

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.register_buffer('my_buffer', torch.tensor([0.5])) # バッファもトレースされる

    def forward(self, x, y):
        x = self.linear1(x)
        x = self.relu(x)
        z = x + y + self.my_buffer # 組み込み演算子とバッファアクセス
        return z

# モデルのインスタンス化
model = SimpleModel()

# モデルをシンボリックトレース
# symbolic_traceはGraphModuleを返します
traced_model = torch.fx.symbolic_trace(model)

print("--- 生成された GraphModule のコード ---")
print(traced_model.code)
print("\n--- グラフの表形式表示 ---")
traced_model.graph.print_tabular()

print("\n--- 各ノードの詳細 ---")
for node in traced_model.graph.nodes:
    print(f"Node Name: {node.name}")
    print(f"  OpCode: {node.opcode}")
    print(f"  Target: {node.target}")
    print(f"  Args: {node.args}")
    print(f"  Kwargs: {node.kwargs}")
    print("-" * 20)

# 変換されたモデルの実行
dummy_x = torch.randn(1, 10)
dummy_y = torch.randn(1, 20)
output_traced = traced_model(dummy_x, dummy_y)
output_original = model(dummy_x, dummy_y)

print(f"\n元のモデルの出力: {output_original}")
print(f"トレースされたモデルの出力: {output_traced}")
assert torch.allclose(output_original, output_traced)
print("出力は一致しています。")

解説:

  • for node in traced_model.graph.nodes:: グラフ内の各Nodeオブジェクトを順番に取得し、その属性にアクセスしています。
    • node.opcode: そのノードが表す操作の種類(例: placeholder, call_module, call_function, get_attr, output)。
    • node.target: そのノードが呼び出す実際のオブジェクト(関数、モジュール名、メソッド名など)。
    • node.args, node.kwargs: その操作に渡される引数。他のNodeオブジェクトを参照している場合、それはデータの依存関係を示します。
  • traced_model.graph.print_tabular(): グラフのノードをテーブル形式で分かりやすく表示します。opcodenametargetargskwargsが列挙されます。
  • traced_model.code: GraphModuleが内部的に保持するPythonコード(最適化されたforwardメソッド)を表示します。
  • torch.fx.symbolic_trace(model): SimpleModelforwardメソッドの実行をトレースし、GraphModuleオブジェクトを生成します。

例2:特定のノードの置き換え

この例では、グラフ内のReLUノードをLeakyReLUノードに置き換える方法を示します。

import torch
import torch.nn as nn
import torch.fx
import operator

class ModelWithReLU(nn.Module):
    def forward(self, x):
        x = nn.Linear(10, 10)(x)
        x = torch.relu(x) # 関数としてのReLU
        return x

model = ModelWithReLU()
traced_model = torch.fx.symbolic_trace(model)

print("--- 変更前のグラフ ---")
traced_model.graph.print_tabular()

# グラフを走査し、ReLUノードをLeakyReLUに置き換える
for node in traced_model.graph.nodes:
    # torch.relu 関数を探す
    if node.op == 'call_function' and node.target == torch.relu:
        print(f"ReLUノード '{node.name}' をLeakyReLUに置き換えます。")

        # 新しいノードを作成する挿入ポイントを設定
        # ここでは元のノードの直後に挿入する
        with traced_model.graph.inserting_after(node):
            # 新しいLeakyReLUノードを作成。引数は元のReLUノードと同じにする
            new_node = traced_model.graph.call_function(
                torch.nn.functional.leaky_relu,
                node.args,
                node.kwargs
            )
        
        # 元のノードを使用していた全ての場所を新しいノードに置き換える
        node.replace_all_uses_with(new_node)
        
        # 不要になった元のノードをグラフから削除する
        traced_model.graph.erase_node(node)

# グラフに変更を加えたので、GraphModuleを再コンパイルする
traced_model.recompile()

print("\n--- 変更後のグラフ ---")
traced_model.graph.print_tabular()
print("\n--- 変更後の GraphModule のコード ---")
print(traced_model.code)

# 変換されたモデルの実行と確認
dummy_input = torch.randn(1, 10)
output_traced = traced_model(dummy_input)
print(f"\n変換後のモデルの出力の形状: {output_traced.shape}")

解説:

  • traced_model.recompile(): グラフの構造を変更した後は、必ずこれを呼び出す必要があります。これにより、GraphModuleforwardメソッドが、変更されたグラフに基づいて再生成されます。
  • traced_model.graph.erase_node(node): 不要になった元のnodeをグラフから削除します。
  • node.replace_all_uses_with(new_node): これが非常に重要なメソッドです。nodeの出力を使用していたグラフ内のすべてのノードが、代わりにnew_nodeの出力を参照するように変更されます。これにより、グラフの接続性が保たれます。
  • traced_model.graph.call_function(target, args, kwargs): call_functionタイプの新しいNodeを作成します。targettorch.nn.functional.leaky_reluargskwargsは元のreluノードからコピーしています。
  • traced_model.graph.inserting_after(node): 新しいノードを現在のnodeの直後に挿入するためのコンテキストマネージャです。

torchvision.models.feature_extraction.create_feature_extractorは、torch.fxのノード操作を利用した典型的な例です。ここではその内部で行われているような、特定の中間層の出力を取り出すシンプルな例を示します。

import torch
import torch.nn as nn
import torch.fx

class FeatureExtractionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        feat1 = self.pool1(x) # 抽出したい中間特徴量1

        x = self.conv2(feat1) # feat1が次の層への入力になる
        x = self.bn2(x)
        x = self.relu2(x)
        feat2 = self.pool2(x) # 抽出したい中間特徴量2

        x = self.avgpool(feat2)
        x = torch.flatten(x, 1)
        output = self.fc(x)
        return output, feat1, feat2 # 複数の出力を返すように変更

model = FeatureExtractionModel()
traced_model = torch.fx.symbolic_trace(model)

# 中間特徴量のノードを見つける
# ノードの名前は、モジュールの名前または操作名に基づいて自動生成される
# 例えば、`self.pool1`の出力は `pool1` という名前のノードになる
target_node_names = ['pool1', 'pool2']
output_nodes = []

for node in traced_model.graph.nodes:
    if node.name in target_node_names:
        output_nodes.append(node)
    
    # 既存の出力ノードを探す (最後の output ノード)
    if node.opcode == 'output':
        output_node = node

# 既存の output ノードの args を変更して、中間特徴量を追加する
# output ノードの target は常に None
# output ノードの args の最初の要素が出力値
# output ノードの args を新しい出力のタプルで置き換える
current_outputs = output_node.args[0] # 現在の出力
if not isinstance(current_outputs, tuple):
    current_outputs = (current_outputs,) # タプルでない場合はタプルに変換

# 新しい出力のリストを作成
new_outputs = list(current_outputs) + output_nodes

# output ノードの args を更新
output_node.args = (tuple(new_outputs),)

# グラフに変更を加えたので、GraphModuleを再コンパイルする
traced_model.recompile()

print("--- 変更後の GraphModule のコード ---")
print(traced_model.code)

# 変換されたモデルの実行
dummy_input = torch.randn(1, 1, 28, 28)
final_output, feat1_output, feat2_output = traced_model(dummy_input)

print(f"\n最終出力の形状: {final_output.shape}")
print(f"中間特徴量1 (pool1) の形状: {feat1_output.shape}")
print(f"中間特徴量2 (pool2) の形状: {feat2_output.shape}")

解説:

  • 注意点として、ノードの名前はtorch.fxが自動的に生成するため、特定のノードを見つけるには、モジュールの名前(self.pool1ならpool1)や関数のターゲット(torch.reluなど)に依存することになります。
  • output_node.args = (tuple(new_outputs),): outputノードのargsは、そのグラフの最終的な出力のタプル(または単一の値)です。ここに、元々の出力に加えて、抽出したい中間特徴量のノードを追加しています。
  • この例では、forwardメソッドのreturn文を変更して、中間層の出力も返すようにoutputノードを直接操作しています。


    • 説明: torch.fx.Nodeを直接操作する代わりに、torch.fx.passesはより高レベルのAPIや、一般的な最適化パスを提供します。これらは内部的にtorch.fx.Nodeを操作しますが、ユーザーが直接ノードを扱う必要がないように抽象化されています。例えば、モジュール融合(Module Fusion)やデータ並列化のための変換などが含まれることがあります。
    • 利点:
      • 特定の最適化パスの実装が簡素化される。
      • 一般的な変換がすでに提供されている場合、再発明の必要がない。
      • 低レベルのノード操作の詳細から解放される。
    • 欠点:
      • 提供されていないカスタムな変換を行うには、やはりFXの低レベルな知識が必要になる場合がある。
    • 使用例: 特定の演算子を融合する、モデルを量子化に適した形に変換する、など。
  1. torch.compile (PyTorch 2.0+ の推奨される最適化)

    • 説明: PyTorch 2.0以降で導入されたtorch.compileは、モデルの実行を高速化するための高レベルなAPIです。内部的には、TorchDynamoがPythonコードをキャプチャし、torch.fxグラフを生成し、そのグラフをAOTAutogradや他のバックエンド(例: Inductor)に渡して最適化されたコードを生成します。ユーザーはtorch.fx.Nodeを意識することなく、モデルの高速化を享受できます。
    • 利点:
      • 最も推奨されるパフォーマンス最適化手法。
      • ほとんどのモデルでコード変更なしに高速化が可能。
      • 内部的にFXが使われるため、FXの恩恵を受けつつ、低レベルな複雑さを回避できる。
      • 動的な制御フローもある程度処理できる。
    • 欠点:
      • 特定の非常に複雑な動的制御フローや、サポートされていない操作では「フォールバック」(Eagerモードに戻ること)が発生し、最適化の恩恵が限定されることがある。
      • デバッグが難しい場合がある。
    • 使用例: モデルのトレーニングや推論を高速化したい場合。
  2. torch.jit.script / torch.jit.trace (TorchScript)

    • 説明: TorchScriptは、PyTorchモデルをPythonインタープリタに依存しない形で実行するためのIRです。
      • torch.jit.script: Pythonコードを解析して、TorchScript IRを構築します。これにより、静的な制御フロー(if, forなど)もグラフとして表現できます。
      • torch.jit.trace: サンプル入力を使ってモデルを実行し、実行パスを記録してTorchScript IRを構築します。動的な制御フローはサポートされません。
    • 利点:
      • モデルのデプロイ(C++、モバイル、エッジデバイスなど)に非常に有用。
      • JITコンパイルによる実行高速化の可能性。
      • モデルのエクスポート形式として広く使われている。
    • 欠点:
      • torch.fxと比較して、グラフ変換や最適化の柔軟性が低い。ノードレベルの操作はFXほど直接的ではない。
      • traceは動的な制御フローを扱えない。
      • scriptは一部のPython機能に制限がある。
    • 使用例: モデルをデプロイしたい場合、Pythonの外部でモデルを実行したい場合。
  3. ONNX (Open Neural Network Exchange)

    • 説明: ONNXは、ニューラルネットワークモデルを表現するためのオープンスタンダード形式です。PyTorchモデルをONNX形式にエクスポートし、ONNXランタイム(ONNX Runtime)や他のONNX互換のフレームワークで実行できます。ONNXもモデルを計算グラフとして表現します。
    • 利点:
      • 異なるフレームワーク間でのモデルの相互運用性を実現。
      • ONNXランタイムなどの最適化された実行環境が利用可能。
    • 欠点:
      • PyTorchのtorch.fxほど低レベルなモデル操作の柔軟性はない。ONNXグラフの操作は通常、特定のONNXツールやライブラリで行われる。
      • PyTorch固有の操作がONNXに直接マッピングできない場合がある。
    • 使用例: TensorFlow、Caffe2などの他のフレームワークとの間でモデルをやり取りする場合、クロスプラットフォームなデプロイ。
  4. カスタムのモデル変換スクリプト (PyTorchモジュールレベル)

    • 説明: torch.nn.Moduleオブジェクトとそのサブモジュールを直接操作するスクリプトを書く方法です。例えば、model.children()model.named_modules()を使ってモデルのレイヤーをイテレートし、特定のレイヤーを別のレイヤーに置き換えるなどです。
    • 利点:
      • FXやTorchScriptの複雑さを導入する必要がない。
      • PyTorchのAPIに慣れていれば直感的に記述できる。
    • 欠点:
      • レイヤー間の「接続」(データの流れ)を自動的に追跡できないため、手動で管理する必要がある。
      • 特定の操作(例: torch.addのような関数呼び出し)はモジュールとして存在しないため、この方法では直接操作できない。
      • 複雑なグラフ変換には不向き。
    • 使用例: 特定のレイヤーの入れ替え、不要なレイヤーの削除、モデルの量子化準備のための特定のモジュールへの変換など、比較的単純なモジュールレベルの変更。
代替方法目的と特徴torch.fx.Node操作との関係
torch.fx.passesFXグラフ上の高レベルな変換/最適化パスを提供。内部的にFXノードを操作するが、ユーザーは直接ノードを扱わない。
torch.compilePyTorchモデルの高速化のための推奨API。内部的にTorchDynamoとFXを利用してグラフを生成・最適化するが、ユーザーはFXを意識する必要がない。
torch.jit.script / traceモデルのデプロイ、C++/モバイル環境での実行。グラフをIRとして表現するが、FXほど柔軟なグラフ操作はできない。
ONNX異なるフレームワーク間のモデル相互運用性。モデルをグラフとして表現するが、PyTorch固有のノード操作とは異なる。
カスタムモジュール操作特定のモジュールの入れ替えや削除など、モジュールレベルのシンプルな変更。グラフの接続性を自動的に追跡せず、関数レベルの操作は扱えない。