PyTorchでワンランク上のプログラミング:OneHotCategorical分布のparam_shape属性を駆使
torch.distributions.one_hot_categorical.OneHotCategorical.param_shape
は、OneHotCategorical
分布のパラメータの形状を表す属性です。これは、分布を定義するために必要な入力データの次元数を示します。
数学的定義
param_shape
は、OneHotCategorical
分布の確率パラメータ probs
または logits
の形状を表すタプルです。
logits
が指定されている場合:param_shape = logits.shape[:-1]
probs
が指定されている場合:param_shape = probs.shape[:-1]
例
以下の例では、probs
を使用して OneHotCategorical
分布を定義しています。
import torch
from torch.distributions import OneHotCategorical
probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)
print(categorical.param_shape)
このコードを実行すると、以下の出力が得られます。
torch.Size([3])
これは、probs
テンソルが形状 [3]
であることを意味し、分布は 3 つのカテゴリを持つことを示します。
解釈
param_shape
属性は、OneHotCategorical
分布を定義するために必要な情報量を理解するのに役立ちます。これは、分布をサンプリングしたり、確率を計算したりする際に必要な入力データの形状を決定するために使用できます。
param_shape
属性は、OneHotCategorical
分布だけでなく、他のtorch.distributions
モジュールの分布でも使用できます。
probs を使用して OneHotCategorical 分布を定義する
import torch
from torch.distributions import OneHotCategorical
probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)
print(categorical.param_shape) # 出力: torch.Size([3])
logits を使用して OneHotCategorical 分布を定義する
import torch
from torch.distributions import OneHotCategorical
logits = torch.tensor([0., 0., 1.])
categorical = OneHotCategorical(logits=logits)
print(categorical.param_shape) # 出力: torch.Size([3])
param_shape を使用して分布をサンプリングする
import torch
from torch.distributions import OneHotCategorical
probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)
sample = categorical.sample((2, 5)) # サンプル形状: (2, 5)
print(sample)
import torch
from torch.distributions import OneHotCategorical
probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)
event = torch.tensor([2, 0, 1]) # イベント
log_prob = categorical.log_prob(event) # 対数確率
print(log_prob)
上記のコード例は、param_shape
属性をどのように使用して OneHotCategorical
分布を定義、サンプリング、確率計算に使用できるかを示しています。
と 2. の例では、
probs
とlogits
のいずれを使用してOneHotCategorical
分布を定義し、param_shape
属性を使用して分布のパラメータの形状を出力します。の例では、
param_shape
を使用して分布からサンプルを生成します。サンプルの形状は、param_shape
の最初の要素 (3) と指定されたサンプルのバッチサイズ (2) の積になります。の例では、
param_shape
を使用して特定のイベントの確率を計算します。log_prob
テンソルの形状は、param_shape
の最初の要素 (3) とイベントのバッチサイズ (3) と同じになります。
これらの例は、param_shape
属性が OneHotCategorical
分布を操作する際に役立つツールであることを示しています。
- コードを実行するには、
torch
パッケージをインストールする必要があります。 - 上記のコード例は、PyTorch 1.12.1 でテストされています。
代替方法
probs または logits の形状を確認する
probs
が指定されている場合:probs.shape[:-1]
logits
が指定されている場合:logits.shape[:-1]
この方法は、
param_shape
属性とほぼ同じ結果を返すシンプルな方法です。distribution.event_shape 属性を使用する
distribution.event_shape
は、分布の確率イベントの形状を表します。OneHotCategorical
分布の場合、これは常に[1]
です。
この方法は、分布のパラメータの形状ではなく、確率イベントの形状を取得する場合に役立ちます。
例
import torch
from torch.distributions import OneHotCategorical
probs = torch.tensor([0.25, 0.25, 0.5])
categorical = OneHotCategorical(probs=probs)
# 代替方法 1
param_shape_1 = probs.shape[:-1]
print(param_shape_1) # 出力: torch.Size([3])
param_shape_2 = logits.shape[:-1]
print(param_shape_2) # 出力: torch.Size([3])
# 代替方法 2
event_shape = categorical.event_shape
print(event_shape) # 出力: torch.Size([1])
考察
- 代替方法 2 は、分布のパラメータの形状ではなく、確率イベントの形状を取得する場合に役立ちます。
- 代替方法 1 は、
param_shape
属性とほぼ同じ結果を返すシンプルでわかりやすい方法です。 param_shape
属性は、OneHotCategorical
分布のパラメータの形状を取得する最も直接的な方法です。
状況に応じて、どの方法が最適かを選択します。