【PyTorch】勾配ゼロ化でモデルを賢くする?「torch.optim.Optimizer.zero_grad()」のしくみとサンプルコード


torch.optim.Optimizer.zero_grad() は、PyTorchにおける最適化プロセスにおいて重要な役割を果たすメソッドです。このメソッドは、モデルのパラメータの勾配をゼロにリセットします。

勾配とは?

勾配は、モデルのパラメータが損失関数に与える影響を表すベクトルです。勾配降下法などの最適化アルゴリズムは、これらの勾配情報に基づいて、モデルのパラメータを更新し、損失関数を最小化していきます。

zero_grad() の必要性

しかし、勾配情報を累積的に利用していく場合、古い勾配情報の影響が残り、最適化の精度が低下する可能性があります。そこで、zero_grad() を用いて、毎回の更新ステップで勾配を初期化することで、この問題を解決します。

具体的な動作

zero_grad() を呼び出すと、モデル内のすべての最適化対象パラメータの .grad 属性がゼロに設定されます。.grad 属性は、各パラメータに対する勾配情報を持つテンソルです。

import torch

# モデル定義
model = ...

# 損失関数定義
criterion = ...

# オプティマイザ定義
optimizer = ...

# 訓練ループ
for epoch in range(num_epochs):
    # ...

    # 勾配計算
    optimizer.zero_grad()  # 勾配をゼロ化
    output = model(input)
    loss = criterion(output, target)
    loss.backward()

    # パラメータ更新
    optimizer.step()
  • PyTorch 1.6以降では、optimizer.step() メソッド内部で自動的に zero_grad() が呼び出されるため、明示的に呼び出す必要はなくなりました。ただし、勾配を累積的に利用したい場合は、引き続き zero_grad() を明示的に呼び出す必要があります。
  • zero_grad() は、勾配計算の前に呼び出す必要があります。勾配計算後に呼び出しても、効果がありません。
  • zero_grad() は、モデル内のすべての最適化対象パラメータに対して適用されます。一部のパラメータのみの勾配をリセットしたい場合は、個別に .grad 属性を操作する必要があります。


勾配をゼロ化してパラメータを更新する基本的な例

import torch

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
y = torch.tensor([3, 5], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 予測と損失計算
    pred = model(x)
    loss = criterion(pred, y)

    # 勾配計算
    optimizer.zero_grad()  # 勾配をゼロ化
    loss.backward()

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(f'Epoch {epoch + 1}: loss = {loss.item():.4f}')

勾配を累積的に利用する場合

import torch

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y = torch.tensor([3, 5, 7], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 勾配をゼロ化
    optimizer.zero_grad()

    # バッチループ
    for i in range(len(x)):
        # 予測と損失計算
        pred = model(x[i])
        loss = criterion(pred, y[i])

        # 勾配計算
        loss.backward()

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(f'Epoch {epoch + 1}: loss = {loss.item():.4f}')
import torch

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y = torch.tensor([3, 5, 7], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 予測と損失計算
    pred = model(x)
    loss = criterion(pred, y)

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(f'Epoch {epoch + 1}: loss = {loss.item():.4f}')

1つ目の例は、最も基本的な使用方法です。この例では、各訓練ステップで勾配をゼロ化し、パラメータを更新しています。



model.parameters() の .grad 属性を直接操作する

import torch

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y = torch.tensor([3, 5, 7], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 予測と損失計算
    pred = model(x)
    loss = criterion(pred, y)

    # 勾配計算
    loss.backward()

    # 勾配をゼロ化
    for param in model.parameters():
        param.grad = torch.zeros_like(param.grad)

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(f'Epoch {epoch + 1}: loss = {loss.item():.4f}')

利点

  • 特定のパラメータのみの勾配をリセットできる
  • コードが簡潔になる

欠点

  • すべての勾配パラメータに対して明示的に操作が必要
  • for ループによる反復処理が必要

.grad 属性にゼロ値を代入する関数を作成する

import torch

def zero_grads(params):
    for param in params:
        param.grad = torch.zeros_like(param.grad)

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y = torch.tensor([3, 5, 7], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 予測と損失計算
    pred = model(x)
    loss = criterion(pred, y)

    # 勾配計算
    loss.backward()

    # 勾配をゼロ化
    zero_grads(model.parameters())

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(f'Epoch {epoch + 1}: loss = {loss.item():.4f}')

利点

  • 関数呼び出しで簡潔に記述できる
  • コードの再利用性が高い

欠点

  • 関数定義が必要
import torch
import torch.nn.utils as utils

# モデル定義
model = nn.Linear(2, 1)

# 損失関数定義
criterion = nn.MSELoss()

# オプティマイザ定義
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 訓練データ
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y = torch.tensor([3, 5, 7], dtype=torch.float32)

# 訓練ループ
for epoch in range(10):
    # 予測と損失計算
    pred = model(x)
    loss = criterion(pred, y)

    # 勾配計算
    loss.backward()

    # 勾配をクリップ
    utils.clip_grad.global_norm_(model.parameters(), max_norm=0.1)  # 勾配の大きさの最大値を0.1に制限

    # パラメータ更新
    optimizer.step()

    # 訓練状況の確認
    print(