深層学習 PyTorch checkpoint でGPUメモリを節約する

2025-05-27

torch.utils.checkpoint.checkpoint() は、PyTorchで非常に長いシーケンスや深いネットワークを学習させる際に、メモリ使用量を削減するために用いられる重要な関数です。その主な目的は、順伝播(forward pass)時に計算された中間的な活性化(activation)をすべてメモリに保持する代わりに、必要になった時点で再計算することです。

具体的には、checkpoint() で囲まれた関数(通常はモデルのレイヤーや一連のレイヤーを含む)の順伝播は通常通り実行されますが、その際に計算される中間的な活性化は保存されません。代わりに、逆伝播(backward pass)が必要になった際に、もう一度同じ順伝播を計算し直すことで、必要な活性化を再生成します。

この仕組みによって、逆伝播時に勾配を計算するために必要となる中間活性化のメモリフットプリントを大幅に削減できます。特に、Transformerのようなメモリ消費量の大きいモデルや、非常に長いシーケンスデータを扱う場合に有効です。

checkpoint() の基本的な使い方

checkpoint() 関数は、以下の引数を取ります。

  • **kwargs: function に渡すキーワード引数。
  • *args: function に渡す位置引数。
  • function: チェックポイントする関数(通常はモデルの一部分)。

そして、チェックポイントされた関数の出力を返します。

コード例

import torch
from torch.utils.checkpoint import checkpoint

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 20)
        self.linear2 = torch.nn.Linear(20, 30)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        # linear2 の出力をチェックポイント
        x = checkpoint(self._intermediate_forward, x)
        return x

    def _intermediate_forward(self, x):
        return torch.relu(self.linear2(x))

# モデルのインスタンス化と入力データの作成
model = MyModule()
input_data = torch.randn(5, 10, requires_grad=True)

# 順伝播
output = model(input_data)

# 損失関数の計算
loss = output.mean()

# 逆伝播
loss.backward()

print("勾配計算完了")

この例では、MyModuleforward メソッド内で linear2 の出力を生成する部分 (self._intermediate_forward) が checkpoint() で囲まれています。これにより、_intermediate_forward の順伝播で生成される中間的な活性化は保存されず、逆伝播時に再計算されます。

checkpoint() の利点と注意点

利点

  • 勾配計算の効率化
    メモリ不足によるエラーを回避し、学習を継続できます。
  • メモリ使用量の削減
    特に深いネットワークや長いシーケンスにおいて顕著です。これにより、より大きなモデルやシーケンスを、限られたGPUメモリで学習できるようになります。
  • torch.no_grad() の影響
    torch.no_grad() コンテキスト内では checkpoint() は通常の関数呼び出しと同様に動作し、中間活性化は保存され、再計算は行われません。
  • 互換性
    checkpoint() で囲む関数は、入力と出力がタプルまたは単一のテンソルであることが推奨されます。複雑なデータ構造を扱う場合は、事前に適切に処理する必要があります。
  • 計算時間の増加
    逆伝播時に順伝播を再計算するため、学習の総計算時間は増加します。したがって、メモリ使用量と計算時間のトレードオフを考慮する必要があります。


一般的なエラーとトラブルシューティング

  1. RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn (テンソルの要素0は勾配を必要とせず、grad_fnを持ちません。)

    • 原因
      チェックポイントされた関数の入力テンソルの一部が requires_grad=True でない場合に、逆伝播時に勾配を計算しようとして発生することがあります。checkpoint() は逆伝播のために順伝播を再計算しますが、その際に勾配を追跡するために必要な情報が不足している可能性があります。
    • 解決策
      checkpoint() に渡す入力テンソルは、勾配計算が必要な場合は requires_grad=True に設定されていることを確認してください。

    <!-- end list -->

    input_tensor_no_grad = torch.randn(10) # requires_grad=False (デフォルト)
    input_tensor_with_grad = torch.randn(10, requires_grad=True)
    
    # エラーの可能性
    # output = checkpoint(my_module, input_tensor_no_grad)
    # loss = output.mean()
    # loss.backward() # エラーが発生する可能性
    
    # 解決策
    output = checkpoint(my_module, input_tensor_with_grad)
    loss = output.mean()
    loss.backward()
    
    • 原因
      • チェックポイントする範囲が不適切
        細かすぎる範囲で checkpoint() を多用すると、再計算のオーバーヘッドが大きくなり、期待したほどのメモリ削減効果が得られないことがあります。
      • 大きなテンソルがチェックポイントの外で保持されている
        checkpoint() で囲まれた部分の中間活性化は再計算されますが、それ以外の大きなテンソルがメモリに保持されたままだと、メモリ使用量の削減効果が限定的になります。
      • 最適化されていない関数の再計算
        チェックポイントされた関数自体が非効率な処理を行っている場合、再計算のコストが高くなります。
    • 解決策
      • 適切な粒度でチェックポイント
        モデルの構造やメモリプロファイルに基づいて、効果的な範囲をチェックポイントするように調整してください。一般的には、TransformerブロックやResNetブロックのような比較的大きな処理単位でチェックポイントすることが推奨されます。
      • 不要なテンソルの解放
        チェックポイントの外で不要になったテンソルは、明示的に del したり、スコープから外したりすることで、メモリから解放するように心がけてください。
      • チェックポイントする関数の最適化
        チェックポイントする関数自体の処理を効率化することも重要です。
      • プロファイリング
        PyTorchのプロファイリングツール (torch.profiler) を使用して、メモリ使用量と計算時間のボトルネックを特定し、チェックポイントの適用範囲を検討するのに役立ちます。
  2. torch.no_grad() コンテキスト内での誤動作

    • 原因
      torch.no_grad() コンテキスト内で checkpoint() を使用すると、中間活性化は保存され、再計算は行われません。これは、勾配計算が不要な推論時には意図された動作ですが、学習中に誤って使用すると、メモリ削減の効果が得られず、予期せぬ動作を引き起こす可能性があります。
    • 解決策
      学習時には torch.no_grad() コンテキストの外で checkpoint() を使用してください。推論時には、メモリ使用量を削減する必要がない場合や、再計算のコストを避けたい場合は、checkpoint() を使用しないか、torch.no_grad() コンテキスト内で使用することを検討してください。
  3. 複雑な制御フローとの組み合わせ

    • 原因
      チェックポイントされた関数内で複雑な条件分岐やループ処理を行うと、逆伝播時の再計算が困難になったり、予期せぬ動作を引き起こしたりする可能性があります。
    • 解決策
      チェックポイントする関数は、比較的単純な順伝播処理のまとまりに留めることが推奨されます。複雑な制御フローは、チェックポイントの外で処理するようにモデルを設計することを検討してください。

トラブルシューティングのヒント

  • PyTorchのドキュメントを確認する
    torch.utils.checkpoint.checkpoint() の詳細な仕様や注意点について、公式ドキュメントを参照してください。
  • 最小限の再現コードを作成する
    問題を特定しやすくするために、エラーが発生する最小限のコードを作成して試してみてください。
  • エラーメッセージをよく読む
    エラーメッセージには、問題の原因に関する重要な情報が含まれています。


基本的な使用例

この例では、簡単な線形層を2つ持つモジュールを作成し、その一方の順伝播処理を checkpoint() で囲みます。

import torch
from torch.utils.checkpoint import checkpoint

class SimpleModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        # linear2 の処理をチェックポイント
        x = checkpoint(self._linear2_forward, x)
        return x

    def _linear2_forward(self, x):
        return torch.relu(self.linear2(x))

# モデルのインスタンス化
input_size = 10
hidden_size = 20
output_size = 5
model = SimpleModel(input_size, hidden_size, output_size)

# 入力データの作成
input_data = torch.randn(1, input_size, requires_grad=True)

# 順伝播
output = model(input_data)
print("順伝播の出力:", output)

# 損失関数の計算
loss = output.mean()

# 逆伝播
loss.backward()
print("逆伝播完了")

# チェックポイントされたレイヤーのパラメータの勾配を確認
print("linear2.weight の勾配:", model.linear2.weight.grad)

この例では、SimpleModelforward メソッド内で、self.linear2 を適用する処理 (self._linear2_forward) が checkpoint() で囲まれています。順伝播時には linear2 の出力は保存されませんが、逆伝播時に checkpoint()_linear2_forward を再実行し、勾配計算に必要な中間活性化を生成します。

複数の引数を取る関数のチェックポイント

checkpoint() に渡す関数が複数の引数を取る場合、それらは *args として checkpoint() に渡されます。

import torch
from torch.utils.checkpoint import checkpoint

def multi_input_function(a, b, c):
    return a + b * torch.relu(c)

# 入力テンソルの作成
tensor_a = torch.randn(2, 3, requires_grad=True)
tensor_b = torch.randn(2, 3, requires_grad=True)
tensor_c = torch.randn(2, 3, requires_grad=True)

# チェックポイントされた関数の呼び出し
output = checkpoint(multi_input_function, tensor_a, tensor_b, tensor_c)
print("出力:", output)

# 損失関数の計算と逆伝播
loss = output.mean()
loss.backward()

print("tensor_a の勾配:", tensor_a.grad)
print("tensor_b の勾配:", tensor_b.grad)
print("tensor_c の勾配:", tensor_c.grad)

ここでは、multi_input_function が3つのテンソルを入力として受け取ります。checkpoint() を呼び出す際には、関数自体とそれに渡す引数を順番に指定します。

キーワード引数を持つ関数のチェックポイント

キーワード引数を持つ関数をチェックポイントする場合、それらは **kwargs として checkpoint() に渡されます。

import torch
from torch.utils.checkpoint import checkpoint

def keyword_argument_function(x, factor=2):
    return x * factor

# 入力テンソルの作成
input_tensor = torch.randn(2, 3, requires_grad=True)

# チェックポイントされた関数の呼び出し (キーワード引数あり)
output = checkpoint(keyword_argument_function, input_tensor, factor=3)
print("出力:", output)

# 損失関数の計算と逆伝播
loss = output.mean()
loss.backward()

print("input_tensor の勾配:", input_tensor.grad)

この例では、keyword_argument_functionx とキーワード引数 factor を取ります。checkpoint() を呼び出す際に、位置引数とキーワード引数を分けて指定します。

より複雑なモデルでの応用 (Transformer Blockの例)

Transformerのような深いネットワークでは、各Transformer Blockを checkpoint() で囲むことで、メモリ使用量を効果的に削減できます。

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_output, _ = self.mha(x, x, x)
        x = x + attn_output
        x = self.norm(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = self.linear2(x)
        x = self.norm(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        self.attn = SelfAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)

    def forward(self, x):
        x = self.attn(x)
        x = self.ff(x)
        return x

class SimpleTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            # 各Transformer Blockをチェックポイント
            x = checkpoint(layer, x)
        return x

# モデルのインスタンス化
embed_dim = 64
num_heads = 8
ff_dim = 256
num_layers = 4
model = SimpleTransformer(embed_dim, num_heads, ff_dim, num_layers)

# 入力データの作成
batch_size = 2
seq_len = 20
input_data = torch.randn(batch_size, seq_len, embed_dim, requires_grad=True)

# 順伝播
output = model(input_data)
print("Transformerの出力形状:", output.shape)

# 損失関数の計算と逆伝播
loss = output.mean()
loss.backward()
print("Transformerの逆伝播完了")

この例では、SimpleTransformer の各 TransformerBlockforward メソッドが checkpoint() で囲まれています。これにより、各ブロックの中間活性化がメモリに保持されることなく、逆伝播時に必要に応じて再計算されます。これは、非常に深いTransformerモデルを学習する際に、メモリ使用量を大幅に削減するのに役立ちます。



勾配蓄積 (Gradient Accumulation)

  • 実装
    学習ループ内で、指定したステップ数ごとに optimizer.step() を呼び出す前に loss.backward() を複数回実行します。各 backward() 呼び出しの前に optimizer.zero_grad() を行わないことが重要です。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    model = nn.Linear(10, 1)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.MSELoss()
    accumulation_steps = 4  # 勾配を蓄積するステップ数
    batch_size = 16
    data_size = 100
    inputs = torch.randn(data_size, 10)
    targets = torch.randn(data_size, 1)
    dataloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(inputs, targets), batch_size=batch_size)
    
    for i, (batch_inputs, batch_targets) in enumerate(dataloader):
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_targets)
        loss = loss / accumulation_steps  # 勾配を平均化するためにスケール
        loss.backward()
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    
  • checkpoint() との違い
    checkpoint() が順伝播の中間活性化を再計算することでメモリを節約するのに対し、勾配蓄積はバッチ処理の効率を高めることで間接的にメモリ問題を緩和します。計算時間は増加しますが、再計算のオーバーヘッドはありません。

モデル並列 (Model Parallelism)

  • 実装
    PyTorchには torch.nn.parallel.DistributedDataParalleltorch.nn.DataParallel (単一ノード内での並列化) などのユーティリティがあり、モデルのレイヤーを異なるデバイスに明示的に配置することも可能です。

    import torch
    import torch.nn as nn
    
    class LargeModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = nn.Linear(10, 1000).to('cuda:0')
            self.layer2 = nn.Linear(1000, 2000).to('cuda:1')
            self.layer3 = nn.Linear(2000, 1).to('cuda:0')
    
        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x.to('cuda:1'))
            x = self.layer3(x.to('cuda:0'))
            return x
    
    model = LargeModel()
    input_data = torch.randn(16, 10).to('cuda:0')
    output = model(input_data)
    
  • checkpoint() との違い
    checkpoint() が単一のGPU内でのメモリ効率化を目指すのに対し、モデル並列は複数のGPUを活用してモデル全体のメモリフットプリントを分散させます。

量子化 (Quantization)

  • 実装
    PyTorchには量子化のためのツールキット (torch.quantization) が用意されており、学習後の量子化 (Post-Training Quantization) や量子化対応学習 (Quantization-Aware Training) を行うことができます。
  • checkpoint() との違い
    checkpoint() が計算グラフの再構築によってメモリを節約するのに対し、量子化はデータの表現方法をよりコンパクトにすることでメモリフットプリントを削減します。

混合精度学習 (Mixed Precision Training)

  • 実装
    PyTorchでは torch.cuda.amp モジュールを利用することで、自動的に混合精度学習を行うことができます。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.cuda.amp import GradScaler, autocast
    
    model = nn.Linear(10, 1).cuda()
    optimizer = optim.Adam(model.parameters())
    criterion = nn.MSELoss().cuda()
    scaler = GradScaler()
    
    inputs = torch.randn(16, 10).cuda()
    targets = torch.randn(16, 1).cuda()
    
    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    
  • checkpoint() との違い
    checkpoint() が中間活性化の保存方法を工夫するのに対し、混合精度学習はテンソルのデータ型を使い分けることでメモリ効率を高めます。

scaler.step(optimizer) scaler.update() ```

より効率的なネットワーク設計

  • checkpoint() との違い
    checkpoint() が既存のモデルのメモリ効率を高めるためのテクニックであるのに対し、効率的なネットワーク設計は、そもそもメモリ使用量が少ないモデルを構築することを目指します。