【初心者歓迎】PyTorch GraphModule をファイルに保存する一番簡単な方法

2025-05-31

  1. Python コード (.py ファイル)
    モデルの構造を定義する Python コードが保存されます。これには、モデルの各ノード(演算やパラメータなど)がどのように接続されているかが記述されています。このコードを見ることで、モデルがどのような計算グラフとして表現されているかを理解することができます。

  2. パラメータとバッファ (.pt ファイル)
    モデルが持つ学習可能なパラメータ(重み、バイアスなど)や、学習には使われないもののモデルの状態を保持するために使われるバッファ(BatchNorm の running mean/variance など)が、PyTorch の torch.save() 関数によって .pt ファイルとして保存されます。

  3. 設定ファイル (.json ファイル)
    モデルのメタデータや、どのように保存されたかに関する情報などが JSON 形式で保存されます。これには、使用された PyTorch のバージョン、保存時のタイムスタンプ、パラメータとバッファのファイル名などが含まれることがあります。

このメソッドの主な目的は以下の通りです。

  • 学習済みパラメータの保持
    学習済みの重みをモデル構造と一緒に保存することで、推論やファインチューニングのためにモデルをロードし直すことができます。
  • 中間表現の保存と共有
    FX で変換されたモデルの中間表現を保存し、他の人と共有したり、後で再利用したりすることができます。
  • デバッグ
    モデルの変換や最適化の過程で、意図しないグラフ構造になっていないかなどを確認する際に便利です。
  • モデルの構造の可視化
    保存された Python コードを見ることで、複雑なモデルの内部構造やデータフローを理解するのに役立ちます。

基本的な使い方

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# 簡単なモデルの定義
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# モデルのインスタンスを作成
model = MyModule()

# symbolic_trace を使って GraphModule に変換
traced_model = symbolic_trace(model)

# 保存先のフォルダを指定
output_folder = "my_traced_model"

# モデルをフォルダに保存
traced_model.to_folder(output_folder, module_name="MyTracedModule")

print(f"モデルは '{output_folder}' フォルダに保存されました。")

上記の例では、MyModule のインスタンスを symbolic_traceGraphModule に変換し、to_folder() メソッドを使って "my_traced_model" というフォルダに保存しています。module_name 引数で、生成される Python ファイル内のモジュール名を指定できます。

保存されたフォルダの中には、例えば以下のようなファイルが含まれているでしょう。

  • my_traced_model/metadata.json: 保存に関するメタデータを含む JSON ファイル (ファイル名は異なる場合があります)
  • my_traced_model/parameters.pt: モデルのパラメータとバッファを保存した PyTorch の Pickle ファイル
  • my_traced_model/my_traced_module.py: モデルのグラフ構造を定義する Python コード


保存先のフォルダが存在しない、または書き込み権限がない

  • トラブルシューティング
    • 指定した output_folder が存在するかどうかを確認してください。存在しない場合は、事前に os.makedirs(output_folder, exist_ok=True) などを使ってフォルダを作成する必要があります。
    • 指定したフォルダに対する書き込み権限があるか確認してください。特に、保護されたディレクトリに保存しようとしている場合に問題が発生しやすいです。
  • エラー
    FileNotFoundError (フォルダが存在しない場合) や PermissionError (書き込み権限がない場合) が発生することがあります。

module_name が Python の命名規則に違反している

  • トラブルシューティング
    • module_name 引数に指定する文字列が、有効な Python のモジュール名であることを確認してください。英数字とアンダースコアのみを使用し、数字で始まらないようにする必要があります。
  • エラー
    生成される Python ファイル名やモジュール名が Python の識別子のルール(先頭が文字またはアンダースコア、以降は文字、数字、アンダースコアのみ)に従っていない場合、エラーが発生する可能性があります。

保存しようとしている GraphModule が空である、または正しくトレースされていない

  • トラブルシューティング
    • symbolic_trace などを使って、モデルが正しく GraphModule に変換されているかを確認してください。print(traced_model.graph) などでグラフの構造を出力してみるのも有効です。
    • トレース時にエラーが発生していないか確認してください。トレースできない操作が含まれている場合、エラーメッセージが表示されることがあります。
  • エラー
    空の GraphModule や、期待通りにモデルがトレースされていない GraphModule を保存しようとしても、意図したファイルが生成されないか、後でロードした際に問題が発生する可能性があります。

パラメータやバッファの保存に失敗する

  • トラブルシューティング
    • 保存先のディスクに十分な空き容量があるか確認してください。
    • ファイルシステムに問題がないか確認してください。他のファイルを同じ場所に保存できるか試してみるのも良いでしょう。
  • エラー
    まれに、モデルのパラメータやバッファの保存 (.pt ファイルの書き込み) に失敗することがあります。これは、ディスク容量の不足やファイルシステムの破損などが原因となる可能性があります。

メタデータ (.json ファイルなど) の書き込みエラー

  • トラブルシューティング
    • 上記 1 と 4 のトラブルシューティングと同様に、フォルダの権限とディスク容量を確認してください。
  • エラー
    設定ファイルなどのメタデータの保存中にエラーが発生することがあります。これも、書き込み権限やディスクの問題が原因となる可能性があります。

カスタムの Node が含まれている場合の取り扱い

  • トラブルシューティング
    • カスタムの Node がどのように実装されているかを確認し、保存された Python コードがその実装に依存していないか確認してください。必要であれば、カスタムの Node を再構築するための追加のコードや情報も保存することを検討してください。
  • 注意点
    GraphModule に、標準的な PyTorch の演算以外のカスタムの Node が含まれている場合、生成される Python コードがそのまま実行できるとは限りません。カスタムの Node が依存するコードも適切に管理し、ロード時に利用できるようにする必要があります。
  • ログ出力を活用する
    必要に応じて、保存処理の前後に変数の状態やグラフの構造をログ出力して確認することで、問題の特定に役立つことがあります。
  • PyTorch のバージョンを確認する
    PyTorch のバージョンによって、挙動やサポートされる機能が異なる場合があります。使用しているバージョンに対応したドキュメントや情報を参照してください。
  • 簡単な例で試す
    まずは簡単なモデルで to_folder() を試してみて、基本的な動作を確認すると良いでしょう。
  • エラーメッセージをよく読む
    PyTorch が出力するエラーメッセージには、問題の原因に関する重要な情報が含まれています。


基本的な使用例

これは、最も基本的な to_folder() の使い方を示す例です。簡単な線形層を持つモデルをトレースし、指定したフォルダに保存します。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import os

# 保存先のフォルダ名
output_folder = "simple_linear_model"

# モデルの定義
class SimpleLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

# モデルのインスタンスを作成
model = SimpleLinear(10, 5)

# モデルを symbolic_trace で GraphModule に変換
traced_model = symbolic_trace(model)

# 保存先のフォルダが存在しない場合は作成
os.makedirs(output_folder, exist_ok=True)

# GraphModule をフォルダに保存
traced_model.to_folder(output_folder, module_name="SimpleLinearModule")

print(f"モデルは '{output_folder}' フォルダに保存されました。")

このコードを実行すると、simple_linear_model というフォルダが作成され、その中にモデルの構造を定義した simple_linear_module.py、パラメータを保存した parameters.pt、そしてメタデータを含むファイル(通常は metadata.json)が保存されます。

module_name を指定する例

module_name 引数を使うことで、生成される Python ファイル内のモジュール名を制御できます。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import os

output_folder = "named_module_model"
module_name = "MyCustomModel"

class AnotherSimpleLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

model = AnotherSimpleLinear(20, 10)
traced_model = symbolic_trace(model)

os.makedirs(output_folder, exist_ok=True)
traced_model.to_folder(output_folder, module_name=module_name)

print(f"モデルは '{output_folder}' フォルダに '{module_name}.py' として保存されました。")

この例では、保存される Python ファイルの名前は my_custom_model.py になり、その中のモジュール名も MyCustomModel となります。

より複雑なモデルの保存例

複数の層を持つ少し複雑なモデルでも、同様に to_folder() を使用できます。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import os

output_folder = "complex_model"

class ComplexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(32 * 5 * 5, 10) # 適当なサイズ

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 5 * 5)
        x = self.fc(x)
        return x

model = ComplexNet()
traced_model = symbolic_trace(model)

os.makedirs(output_folder, exist_ok=True)
traced_model.to_folder(output_folder, module_name="ComplexModel")

print(f"複雑なモデルは '{output_folder}' フォルダに保存されました。")

この例では、畳み込み層やプーリング層、全結合層を含む CNN モデルをトレースして保存しています。保存された Python ファイルには、これらの層の接続関係が記述されます。

保存されたモデルのロード

to_folder() で保存したモデルは、torch.fx.GraphModule.from_folder() を使ってロードできます。

import torch
from torch.fx import GraphModule

# 保存したフォルダ名
loaded_folder = "simple_linear_model"

# フォルダから GraphModule をロード
loaded_model = GraphModule.from_folder(loaded_folder)

# ロードされたモデルの構造を表示
print("ロードされたモデルのグラフ:")
print(loaded_model.graph)

# ロードされたモデルを使って推論を行う (パラメータもロードされている)
input_tensor = torch.randn(1, 10)
output = loaded_model(input_tensor)
print("ロードされたモデルの出力:", output)

この例では、最初に保存した simple_linear_model フォルダから GraphModule をロードし、そのグラフ構造を表示しています。ロードされたモデルは、保存時のパラメータも保持しているため、そのまま推論に使うことができます。



torch.save() を使用して GraphModule オブジェクト全体を保存する

最も直接的な代替方法は、torch.save() 関数を使って GraphModule オブジェクトそのものを保存することです。これにより、モデルのグラフ構造とパラメータが単一のファイルにシリアライズされます。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# モデルの定義
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
traced_model = symbolic_trace(model)

# GraphModule オブジェクトをファイルに保存
save_path = "traced_model.pt"
torch.save(traced_model, save_path)

print(f"GraphModule オブジェクトは '{save_path}' に保存されました。")

# 保存した GraphModule オブジェクトをロード
loaded_model = torch.load(save_path)
print("ロードされたモデルのグラフ:")
print(loaded_model.graph)

利点

  • シンプル
    単一の関数呼び出しで保存とロードが完了します。

欠点

  • バージョン依存
    PyTorch のバージョンが変わると、ロードできなくなる可能性があります。
  • 人間が読みにくい
    保存されたファイルはバイナリ形式であり、直接内容を理解することは困難です。

モデルの構造(グラフ)を文字列として保存する

GraphModule オブジェクトの .graph 属性は、モデルのグラフ構造を表す Graph オブジェクトです。このグラフを文字列として取得し、ファイルに保存することができます。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class AnotherSimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.linear = nn.Linear(3, 4)

    def forward(self, x):
        return self.linear(self.relu(x))

model = AnotherSimpleModel()
traced_model = symbolic_trace(model)

# グラフ構造を文字列として取得
graph_string = str(traced_model.graph)

# 文字列をファイルに保存
graph_file_path = "model_graph.txt"
with open(graph_file_path, "w") as f:
    f.write(graph_string)

print(f"モデルのグラフ構造は '{graph_file_path}' に保存されました。")

# パラメータは別途 torch.save() で保存する必要がある
parameters_path = "model_params.pt"
torch.save(model.state_dict(), parameters_path)
print(f"モデルのパラメータは '{parameters_path}' に保存されました。")

# ロードする場合は、文字列からグラフを再構築し、パラメータをロードする必要がある (複雑)

利点

  • グラフ構造を人間が読める形式で保存できる
    テキストファイルとして保存されるため、モデルの構造を理解しやすくなります。

欠点

  • パラメータは別途保存・ロードが必要
    グラフ構造とパラメータを別々に管理する必要があります。
  • 再構築が複雑
    保存した文字列から GraphModule を直接再構築する簡単な方法はありません。手動で解析し、ノードやエッジを再作成する必要があります。

ONNX (Open Neural Network Exchange) 形式でエクスポートする

ONNX は、異なるフレームワーク間でモデルを交換するためのオープンな標準形式です。PyTorch のモデル(GraphModule を含む)を ONNX 形式にエクスポートできます。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import torch.onnx

class YetAnotherModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(self.sigmoid(x))

model = YetAnotherModel()
traced_model = symbolic_trace(model)

# ダミーの入力テンソル
dummy_input = torch.randn(1, 2)

# ONNX 形式でエクスポート
onnx_path = "model.onnx"
torch.onnx.export(traced_model, dummy_input, onnx_path)

print(f"モデルは ONNX 形式で '{onnx_path}' にエクスポートされました。")

# ONNX モデルは他のフレームワークや ONNX ランタイムでロードして使用できる

利点

  • 標準化された形式
    モデルの交換やデプロイメントに適しています。
  • 互換性
    他の深層学習フレームワークや ONNX ランタイムでロードして使用できます。

欠点

  • 直接的な Python コードとしての保存ではない
    モデルの構造は ONNX のプロトコルバッファ形式で保存されます。
  • FX 固有の情報が失われる可能性
    ONNX はすべての PyTorch FX の機能を完全にサポートしているわけではありません。

カスタムのシリアライズ/デシリアライズ処理を実装する

より柔軟な方法として、GraphModule の内部構造(ノード、エッジ、パラメータなど)を直接操作し、カスタムの形式で保存およびロードする処理を実装することも可能です。これには、traced_model.graph.nodestraced_model.state_dict() などを利用します。

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
import json

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tanh = nn.Tanh()
        self.linear = nn.Linear(4, 3)

    def forward(self, x):
        return self.linear(self.tanh(x))

model = CustomModel()
traced_model = symbolic_trace(model)

# カスタム形式でグラフ構造とパラメータを保存 (JSON + torch.save)
custom_data = {
    "graph": [str(node) for node in traced_model.graph.nodes],
    "state_dict_path": "custom_params.pt"
}
with open("custom_model.json", "w") as f:
    json.dump(custom_data, f)
torch.save(traced_model.state_dict(), custom_data["state_dict_path"])

print("モデルをカスタム形式で保存しました。")

# ロード処理はさらに複雑になる

利点

  • 完全な制御
    保存形式を完全にカスタマイズできます。
  • 保守が難しい
    カスタム形式の変更には対応が必要です。
  • 実装が複雑
    保存とロードのロジックを自分で記述する必要があります。