【初心者向け】PyTorchで画像認識CNNを作る際に役立つ「torch.nn.MaxPool2d」のわかりやすい解説


torch.nn.MaxPool2dは、畳み込みニューラルネットワーク(CNN)で使用されるプーリング層の1つです。入力された画像テンソルに対して、指定されたカーネルサイズで窓移動を行い、各窓内の最大値を抽出します。画像サイズを縮小しながら、特徴マップの重要な特徴を抽出するために役立ちます。

基本的な使い方

import torch.nn as nn

# 畳み込み層の後にプーリング層を定義
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# 畳み込み層の出力をプーリング層に入力
output = max_pool(input_tensor)

この例では、kernel_size=2と設定しているので、2x2の窓で最大値プーリングを行います。stride=2と設定しているので、プーリングされた出力テンソルのサイズは元の入力テンソルの半分のサイズになります。

  • return_indices (bool, optional): 最大値の位置を返すかどうかを指定します。デフォルトはFalseです。
  • dilation (tuple, optional): 拡張率を指定します。(D_h, D_w)の形式で、D_hはカーネルの高さ方向の拡張率、D_wはカーネルの幅方向の拡張率を表します。デフォルトは1です。
  • padding (tuple, optional): パディングを指定します。(P_h, P_w)の形式で、P_hは入力テンソルの高さ方向のパディング、P_wは入力テンソルの幅方向のパディングを表します。デフォルトは0です。
  • stride (tuple, optional): ストライドを指定します。(H_s, W_s)の形式で、H_sは出力テンソルの高さ方向のストライド、W_sは出力テンソルの幅方向のストライドを表します。デフォルトはkernel_sizeと同じ値です。
  • kernel_size (tuple): カーネルサイズを指定します。(H, W)の形式で、Hはカーネルの高さ、Wはカーネルの幅を表します。
  • プーリング層は、過学習を防ぎ、モデルの汎化性能を向上させるのに役立ちます。
  • 他のプーリング層として、AveragePool2dL1Pool2dなどがあります。
  • MaxPool2dは、画像認識や物体検出などのタスクでよく使用されます。


基本的な使い方

import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 224, 224)

# 畳み込み層とプーリング層を定義
conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# 畳み込み層とプーリング層を順に適用
output = conv(input_tensor)
output = max_pool(output)

# 出力テンソルのサイズを確認
print(output.size())

このコードを実行すると、以下の出力が得られます。

torch.Size([1, 16, 112, 112])

入力テンソルは1x3x224x224のサイズでしたが、プーリング層を通過したことで1x16x112x112のサイズになりました。

ストライドとパディング

import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# 畳み込み層とプーリング層を定義
conv = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
max_pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0)

# 畳み込み層とプーリング層を順に適用
output = conv(input_tensor)
output = max_pool(output)

# 出力テンソルのサイズを確認
print(output.size())
torch.Size([1, 16, 13, 13])

stride=2と設定しているので、プーリングされた出力テンソルのサイズは元の入力テンソルの半分になります。また、padding=1と設定しているので、入力テンソルの境界を1ピクセルパディングしてからプーリングを行います。

複数のパラメータを指定

import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 32, 32)

# 畳み込み層とプーリング層を定義
conv = nn.Conv2d(3, 16, kernel_size=(3, 5), stride=(2, 1), padding=(1, 2))
max_pool = nn.MaxPool2d(kernel_size=(2, 3), stride=(1, 2), dilation=(2, 1))

# 畳み込み層とプーリング層を順に適用
output = conv(input_tensor)
output = max_pool(output)

# 出力テンソルのサイズを確認
print(output.size())
torch.Size([1, 16, 14, 15])
  • dilation=(2, 1): 拡張率は高さ方向に2、幅方向に1になります。
  • padding=(1, 2): パディングは高さ方向に1、幅方向に2になります。
  • stride=(2, 1): ストライドは高さ方向に2、幅方向に1になります。
  • kernel_size=(3, 5): カーネルサイズは高さ方向に3、幅方向に5になります。

これらのパラメータを組み合わせることで、さまざまなプーリング操作を実現できます。

import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# プーリング層を定義
max_pool = nn.MaxPool2d(kernel_size=2, return_indices=True)

# プーリングを実行
output, indices = max_pool(input_tensor)

# 出力テンソルとインデックスを確認
print(output.size())
print(indices.size())


平均プーリング (Average Pooling)

torch.nn.AvgPool2dは、入力窓内の平均値を計算するプーリング層です。MaxPool2dと同様に、画像サイズを縮小しながら特徴量を抽出することができます。

利点

  • 局所的な特徴だけでなく、全体的な傾向も捉えることができる
  • ノイズの影響を受けにくい

欠点

  • MaxPool2dよりも識別精度が低くなる場合がある
  • 最大値の特徴を失ってしまう可能性がある


import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# 平均プーリング層を定義
avg_pool = nn.AvgPool2d(kernel_size=2)

# 平均プーリングを実行
output = avg_pool(input_tensor)

# 出力テンソルのサイズを確認
print(output.size())

L1プーリング (L1 Pooling)

torch.nn.L1Pool2dは、入力窓内の絶対値の総和を計算するプーリング層です。MaxPool2dAvgPool2dとは異なり、スパースな表現を生成することができます。

利点

  • ロバスト性が高い
  • スパースな表現を生成できる

欠点

  • 他プーリング層よりも識別精度が低くなる場合がある
  • 計算量が多い


import torch
import torch.nn as nn

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# L1プーリング層を定義
l1_pool = nn.L1Pool2d(kernel_size=2)

# L1プーリングを実行
output = l1_pool(input_tensor)

# 出力テンソルのサイズを確認
print(output.size())

ガウスプーリング (Gaussian Pooling)

GaussianPool2dは、ガウス関数で重み付けされた平均値を計算するプーリング層です。周辺のピクセルほど重みが大きくなるように計算されます。

利点

  • ノイズの影響を受けにくい
  • 周辺の情報も考慮したプーリングが可能

欠点

  • 実装が複雑
  • 計算量が多い


import torch
import torch.nn as nn
from torchvision.transforms import GaussianBlur

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# ガウスフィルタを作成
gaussian_blur = GaussianBlur(kernel_size=2, sigma=(0.5, 0.5))

# ガウスフィルタを適用
blurred_input = gaussian_blur(input_tensor)

# 畳み込み層とプーリング層を定義
conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
avg_pool = nn.AvgPool2d(kernel_size=2)

# 畳み込み層とプーリング層を順に適用
output = conv(blurred_input)
output = avg_pool(output)

# 出力テンソルのサイズを確認
print(output.size())

空間アテンション (Spatial Attention)

空間アテンションは、入力テンソルの各位置に対して重要度を計算し、重要な部分のみを抽出する手法です。プーリング層とは異なり、柔軟な特徴量抽出が可能になります。

利点

  • 入力画像の構造を考慮した特徴量抽出が可能
  • 重要な部分のみを抽出できる

欠点

  • 実装が複雑
  • 計算量が多い
import torch
import torch.nn as nn
from torch.nn import functional as F

# 入力データを作成
input_tensor = torch.randn(1, 3, 28, 28)

# 空間アテンションモジュールを定義
class SpatialAttentionModule(nn.Module):
    def __init__(