【初心者向け】PyTorch「torch.optim.Optimizer.state_dict()」のサンプルコードで理解を深める


PyTorchで機械学習モデルを構築する際、最適化アルゴリズムは学習プロセスにおいて重要な役割を果たします。代表的な最適化アルゴリズムの一つである「torch.optim」は、モデルのパラメータを更新し、損失関数を最小化するように努めます。

このチュートリアルでは、torch.optim.Optimizer クラスの重要なメソッドである state_dict() に焦点を当てます。このメソッドは、最適化アルゴリズムの状態を辞書形式で取得し、保存や復元に利用することができます。

state_dict() の役割

state_dict() メソッドは、以下の情報を格納した辞書を返します。

  • パラメータ状態
    各パラメータに対する現在の状態情報 (例:学習率、過去の勾配情報など)
  • パラメータグループ
    それぞれの学習率や更新規則などに関する情報を持つパラメータグループのリスト

この辞書を保存することで、以下の操作が可能になります。

  • モデルの共有
    保存された状態を共有することで、他のユーザーが同じモデルの学習を再現することができます。
  • 異なるデバイス間での転送
    保存された状態を別のデバイスに移行し、そこで学習を継続することができます。
  • 学習の再開
    保存された状態を使って、中断した学習を再開することができます。

state_dict() の使い方

state_dict() メソッドは、以下のコードのように簡単に呼び出すことができます。

optimizer_state_dict = optimizer.state_dict()

このコードを実行すると、現在の最適化アルゴリズムの状態が optimizer_state_dict 変数に格納されます。

保存と復元

state_dict() で取得した辞書は、torch.save() 関数を使って保存することができます。

torch.save(optimizer_state_dict, "optimizer_state.pt")

同様に、torch.load() 関数を使って、保存された状態を復元することができます。

loaded_state_dict = torch.load("optimizer_state.pt")

復元された状態を新しいオプティマイザにロードするには、load_state_dict() メソッドを使用します。

new_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
new_optimizer.load_state_dict(loaded_state_dict)
  • 保存する前に、モデルとオプティマイザが同じデバイスにあることを確認してください。
  • 異なるバージョンの PyTorch で保存された状態をロードする場合、互換性の問題が発生する可能性があります。
  • state_dict() で保存されるのは、最適化アルゴリズムの状態のみです。モデルのパラメータ自体は含まれません。


モデルとオプティマイザの定義

まず、簡単なモデルとオプティマイザを定義します。

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

state_dict の保存

次に、state_dict() メソッドを使用してオプティマイザの状態を保存します。

optimizer_state_dict = optimizer.state_dict()
torch.save(optimizer_state_dict, "optimizer_state.pt")

state_dict の復元

保存された状態を復元し、新しいオプティマイザにロードします。

loaded_state_dict = torch.load("optimizer_state.pt")
new_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
new_optimizer.load_state_dict(loaded_state_dict)

復元されたオプティマイザの使用

復元されたオプティマイザを使用して、モデルのパラメータを更新することができます。

# ... トレーニングコード ...

optimizer = new_optimizer  # 復元されたオプティマイザを使用

説明

このコード例では、シンプルな線形モデルと SGD オプティマイザを定義しています。その後、state_dict() メソッドを使用してオプティマイザの状態を保存し、torch.save() 関数を使用してファイルに保存します。

次に、torch.load() 関数を使用して保存された状態を復元し、新しいオプティマイザにロードします。最後に、復元されたオプティマイザを使用してモデルのパラメータを更新します。

この例は、torch.optim.Optimizer.state_dict() の基本的な使用方法を示しています。実際の使用例では、より複雑なモデルとオプティマイザを使用する可能性がありますが、基本的な原理は同じです。

  • 保存と復元の操作は、異なるデバイス間で行うこともできます。
  • より複雑なモデルとオプティマイザを使用する場合は、コードをそれに応じて調整する必要があります。
  • このコードは、PyTorch 1.8.0 で動作確認済みです。


カスタムパラメータ

シンプルな代替方法として、モデルのパラメータと共にカスタムパラメータを保存する方法があります。この方法は、以下の利点があります。

  • 他のライブラリとの互換性が高い
  • シンプルで理解しやすい

一方、以下の欠点もあります。

  • モデルとパラメータを別々に保存する必要がある
  • オプティマイザの状態を完全に保存できない場合がある

カスタムパラメータを使用する例:

import torch

class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr):
        super(MyOptimizer, self).__init__(params, lr)
        self.custom_param = 0.1

    def step(self):
        for group in self.param_groups:
            for param in group['params']:
                param.data -= self.lr * param.grad
        self.custom_param += 0.01

model = torch.nn.Linear(10, 1)
optimizer = MyOptimizer(model.parameters(), lr=0.01)

# ... トレーニングコード ...

state = {
    'model': model.state_dict(),
    'optimizer_custom_param': optimizer.custom_param,
}
torch.save(state, "checkpoint.pt")

カスタム状態オブジェクト

より複雑な代替方法として、カスタム状態オブジェクトを作成する方法があります。この方法は、以下の利点があります。

  • モデルと状態を一緒に保存できる
  • オプティマイザの状態を完全に保存できる
  • 他のライブラリとの互換性が低い可能性がある
  • カスタムオブジェクトの作成と実装が複雑になる

カスタム状態オブジェクトを使用する例:

import torch

class MyOptimizerState:
    def __init__(self, optimizer):
        self.param_groups = optimizer.param_groups
        self.state = optimizer.state

def save_optimizer_state(optimizer, filename):
    state_obj = MyOptimizerState(optimizer)
    torch.save(state_obj, filename)

def load_optimizer_state(filename, optimizer):
    state_obj = torch.load(filename)
    optimizer.param_groups = state_obj.param_groups
    optimizer.state = state_obj.state

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# ... トレーニングコード ...

save_optimizer_state(optimizer, "optimizer_state.pt")

Pickleや joblib などのライブラリを使用して、torch.optim.Optimizer.state_dict() 以外でオブジェクトを保存することもできます。これらのライブラリは、より汎用性が高く、複雑なオブジェクトを保存するのに適しています。

適切な方法の選択

どの代替方法が適切かは、状況によって異なります。以下の点を考慮する必要があります。

  • 互換性
    他のライブラリやツールとの互換性はどうだろうか?
  • 複雑性
    カスタムオブジェクトの作成と実装はどれほど複雑か?
  • 必要な機能
    オプティマイザの状態を完全に保存する必要があるか、それとも部分的な保存で十分か?