PyTorchでワンランク上のプログラミング:OneHotCategorical分布のparam_shape属性を駆使


torch.distributions.one_hot_categorical.OneHotCategorical.param_shape は、OneHotCategorical 分布のパラメータの形状を表す属性です。これは、分布を定義するために必要な入力データの次元数を示します。

数学的定義

param_shape は、OneHotCategorical 分布の確率パラメータ probs または logits の形状を表すタプルです。

  • logits が指定されている場合: param_shape = logits.shape[:-1]
  • probs が指定されている場合: param_shape = probs.shape[:-1]

以下の例では、probs を使用して OneHotCategorical 分布を定義しています。

import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)

print(categorical.param_shape)

このコードを実行すると、以下の出力が得られます。

torch.Size([3])

これは、probs テンソルが形状 [3] であることを意味し、分布は 3 つのカテゴリを持つことを示します。

解釈

param_shape 属性は、OneHotCategorical 分布を定義するために必要な情報量を理解するのに役立ちます。これは、分布をサンプリングしたり、確率を計算したりする際に必要な入力データの形状を決定するために使用できます。

  • param_shape 属性は、OneHotCategorical 分布だけでなく、他の torch.distributions モジュールの分布でも使用できます。


probs を使用して OneHotCategorical 分布を定義する

import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)

print(categorical.param_shape)  # 出力: torch.Size([3])

logits を使用して OneHotCategorical 分布を定義する

import torch
from torch.distributions import OneHotCategorical

logits = torch.tensor([0., 0., 1.])
categorical = OneHotCategorical(logits=logits)

print(categorical.param_shape)  # 出力: torch.Size([3])

param_shape を使用して分布をサンプリングする

import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)

sample = categorical.sample((2, 5))  # サンプル形状: (2, 5)

print(sample)
import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)

event = torch.tensor([2, 0, 1])  # イベント
log_prob = categorical.log_prob(event)  # 対数確率

print(log_prob)

上記のコード例は、param_shape 属性をどのように使用して OneHotCategorical 分布を定義、サンプリング、確率計算に使用できるかを示しています。

  1. と 2. の例では、probslogits のいずれを使用して OneHotCategorical 分布を定義し、param_shape 属性を使用して分布のパラメータの形状を出力します。

  2. の例では、param_shape を使用して分布からサンプルを生成します。サンプルの形状は、param_shape の最初の要素 (3) と指定されたサンプルのバッチサイズ (2) の積になります。

  3. の例では、param_shape を使用して特定のイベントの確率を計算します。log_prob テンソルの形状は、param_shape の最初の要素 (3) とイベントのバッチサイズ (3) と同じになります。

これらの例は、param_shape 属性が OneHotCategorical 分布を操作する際に役立つツールであることを示しています。

  • コードを実行するには、torch パッケージをインストールする必要があります。
  • 上記のコード例は、PyTorch 1.12.1 でテストされています。


代替方法

  1. probs または logits の形状を確認する

    • probs が指定されている場合: probs.shape[:-1]
    • logits が指定されている場合: logits.shape[:-1]

    この方法は、param_shape 属性とほぼ同じ結果を返すシンプルな方法です。

  2. distribution.event_shape 属性を使用する

    • distribution.event_shape は、分布の確率イベントの形状を表します。OneHotCategorical 分布の場合、これは常に [1] です。

    この方法は、分布のパラメータの形状ではなく、確率イベントの形状を取得する場合に役立ちます。

import torch
from torch.distributions import OneHotCategorical

probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)

# 代替方法 1
param_shape_1 = probs.shape[:-1]
print(param_shape_1)  # 出力: torch.Size([3])

param_shape_2 = logits.shape[:-1]
print(param_shape_2)  # 出力: torch.Size([3])

# 代替方法 2
event_shape = categorical.event_shape
print(event_shape)  # 出力: torch.Size([1])

考察

  • 代替方法 2 は、分布のパラメータの形状ではなく、確率イベントの形状を取得する場合に役立ちます。
  • 代替方法 1 は、param_shape 属性とほぼ同じ結果を返すシンプルでわかりやすい方法です。
  • param_shape 属性は、OneHotCategorical 分布のパラメータの形状を取得する最も直接的な方法です。

状況に応じて、どの方法が最適かを選択します。