【超解説】PyTorchにおける3D平均プーリングのすべて ~『torch.nn.functional.avg_pool3d』を使いこなして、畳み込みニューラルネットワークをレベルアップ!~


torch.nn.functional.avg_pool3dは、PyTorchのnn.functionalモジュールにある3D平均プーリング関数を用いる関数です。この関数は、入力テンソルを指定されたサイズで分割し、各領域の平均値を計算して出力テンソルを生成します。畳み込みニューラルネットワーク(CNN)において、特徴マップの次元を縮小したり、ノイズを低減したりするために使用されます。

引数

  • dtype: 出力テンソルのデータ型
  • ceil_mode: 天井モードフラグ。Trueの場合、出力テンソルの形状は入力テンソルの形状をceilで計算したものになります。Falseの場合、floorで計算されます。
  • padding: パディング。形状は(padding_d, padding_h, padding_w)である必要があり、ここで:
    • padding_dは深度方向パディング
    • padding_hは高さ方向パディング
    • padding_wは幅方向パディング
  • stride: プーリングカーネルのストライド。形状は(stride_d, stride_h, stride_w)である必要があり、ここで:
    • stride_dはプーリングカーネルの深度方向ストライド
    • stride_hはプーリングカーネルの高さ方向ストライド
    • stride_wはプーリングカーネルの幅方向ストライド
  • kernel_size: プーリングカーネルのサイズ。形状は(kernel_size_d, kernel_size_h, kernel_size_w)である必要があり、ここで:
    • kernel_size_dはプーリングカーネルの深度
    • kernel_size_hはプーリングカーネルの高さ
    • kernel_size_wはプーリングカーネルの幅
  • input: 入力テンソル(4Dテンソル)。形状は(N, C, D, H, W)である必要があり、ここで:
    • Nはバッチサイズ
    • Cは入力チャネル数
    • Dは入力テンソルの深度
    • Hは入力テンソルの高さ
    • Wは入力テンソルの幅

戻り値

出力テンソル(4Dテンソル)。形状は(N, C, oD, oH, oW)です。ここで:

  • oWは出力テンソルの幅
  • oHは出力テンソルの高さ
  • oDは出力テンソルの深度

import torch
import torch.nn.functional as F

input = torch.randn(2, 3, 4, 5, 5)
output = F.avg_pool3d(input, kernel_size=2, stride=1, padding=1)
print(output.size())

この例では、形状が(2, 3, 4, 5, 5)の入力テンソルに対して、2x2x2のプーリングカーネルで平均プーリングを行い、ストライド1、パディング1で出力を生成します。出力テンソルの形状は(2, 3, 2, 3, 3)になります。

  • 平均プーリングは、最大プーリングよりもノイズに対してロバストですが、空間的な特徴情報を失いやすくなります。
  • プーリングカーネルのサイズ、ストライド、パディングを適切に設定することで、ネットワークの性能を調整することができます。
  • torch.nn.AvgPool3dモジュールを使用して、同様の機能を持つプーリング層を構築することもできます。モジュールの方が柔軟性が高く、属性を設定したり、他のモジュールと組み合わせたりすることができます。


import torch
import torch.nn.functional as F

# 入力データを作成
data = torch.randn(2, 3, 16, 16, 16)  # N, C, D, H, W の形状

# プーリングを実行
output = F.avg_pool3d(data, kernel_size=2, stride=2)  # 2x2x2 のプーリングカーネル、ストライド2

# 結果を出力
print(output.size())  # (2, 3, 8, 8, 8) の形状

このコードは以下の処理を実行します。

  1. ランダムな 3D ボリュームデータを作成します。データの形状は (2, 3, 16, 16, 16) で、ここで:
    • 2 はバッチサイズ
    • 3 は入力チャネル数
    • 16 は入力テンソルの深度、高さ、幅
  2. F.avg_pool3d 関数を使用して、入力データに対して 2x2x2 のプーリングカーネルで平均プーリングを実行します。ストライドは 2 に設定されているため、出力テンソルの空間寸法は半分になります。
  3. プーリングの結果を出力します。出力テンソルの形状は (2, 3, 8, 8, 8) になります。

このコードは、基本的な例であり、状況に応じて拡張することができます。例えば、以下のように変更することができます。

  • 出力テンソルのデータ型を指定する
  • 天井モードフラグを設定する
  • パディングを設定する
  • 異なるサイズのプーリングカーネルとストライドを使用する


torch.nn.AvgPool3dモジュール

  • 欠点:
    • torch.nn.functional.avg_pool3dよりも冗長な記述になる場合があります。
  • 利点:
    • torch.nn.functional.avg_pool3dよりも柔軟性が高く、属性を設定したり、他のモジュールと組み合わせたりすることができます。
    • プーリング操作をネットワークのアーキテクチャの一部として定義することができます。
import torch
import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.avg_pool = nn.AvgPool3d(kernel_size=2, stride=1, padding=1)

    def forward(self, x):
        output = self.avg_pool(x)
        return output

カスタムカーネルを使用した畳み込み

  • 欠点:
    • torch.nn.functional.avg_pool3dよりも計算コストが高くなる場合があります。
    • コードが複雑になる可能性があります。
  • 利点:
    • プーリングカーネルの形状やウェイトを自由に設定することができます。
    • 特定のタスクに合わせたプーリング操作を設計することができます。
import torch
import torch.nn.functional as F

def custom_avg_pool3d(x, kernel_size, stride=1, padding=0):
    # カスタムカーネルを作成
    kernel = torch.ones(kernel_size, dtype=x.dtype, device=x.device) / kernel_size.numel()

    # 畳み込み操作を実行
    output = F.conv3d(x, kernel, stride=stride, padding=padding)

    return output
  • 欠点:

    • それぞれのプーリング関数の特性を理解する必要があります。
  • 利点:

    • 状況に応じて適切なプーリング関数を選択することで、ネットワークの性能を向上させることができます。
  • 各プーリング関数は、それぞれ異なる特性を持っています。

  • PyTorchには、最大プーリング、L2プーリング、ガウスプーリングなど、さまざまなプーリング関数があります。

    • 空間的な特徴情報を失う可能性があります。
    • 計算コストを削減することができます。
  • 状況によっては、3Dプーリングではなく、1Dまたは2Dプーリングを使用して次元を削減することができます。