初心者向け! PyTorch Categorical.log_prob() 関数でカテゴリカル分布を扱うチュートリアル


torch.distributions.categorical.Categorical.log_prob() は、PyTorch の確率分布モジュールにおける重要な関数の一つです。これは、カテゴリカル分布に従うランダム変数における特定の事象の対数確率を計算するために使用されます。

本解説では、torch.distributions.categorical.Categorical.log_prob() の役割と使用方法を、以下の内容に分けて丁寧に解説します。

カテゴリカル分布とは

カテゴリカル分布は、有限個のカテゴリからなる離散型確率分布です。 コイン投げ(表と裏の2つのカテゴリ)や サイコロの目(1から6までの6つのカテゴリ)などがわかりやすい例です。 各カテゴリは、事象の発生確率を表す確率値を持ちます。

torch.distributions.categorical.Categorical とは

torch.distributions.categorical.Categorical は、カテゴリカル分布を表現するための確率分布クラスです。 このクラスは、各カテゴリの確率値 (probs) を引数として受け取り、カテゴリカル分布を定義します。

log_prob() 関数の役割

log_prob() 関数は、カテゴリカル分布に従うランダム変数における特定の事象の対数確率を計算します。 これは、事象が起こる確率の自然対数を求めることを意味します。 対数確率は、確率値を直接扱うよりも数値的に扱いやすい性質を持ちます。

log_prob() 関数の引数

log_prob() 関数は、以下の2つの引数を受け取ります。

  • validate_args: 計算中に引数が無効かどうかを確認するオプションフラグ。 デフォルトは False です。
  • value: 対数確率を求めるカテゴリの値。 これは、整数またはテンソルで指定できます。

log_prob() 関数の戻り値

log_prob() 関数は、value がカテゴリカル分布に従うランダム変数として発生する対数確率を含むテンソルを返します。 テンソルの形状は、value の形状と同じになります。

log_prob() 関数の例

import torch
from torch.distributions import Categorical

# カテゴリカル分布を定義
probs = torch.tensor([0.2, 0.5, 0.3])
cat = Categorical(probs)

# 特定の事象の対数確率を計算
value = torch.tensor(2)  # カテゴリ 2 (3番目の要素)
log_prob = cat.log_prob(value)
print(log_prob)  # 出力: tensor(-0.4054)

この例では、3つのカテゴリを持つカテゴリカル分布を定義し、カテゴリ 2 の対数確率を計算しています。

torch.distributions.categorical.Categorical.log_prob() 関数は、カテゴリカル分布における特定の事象の対数確率を計算するための重要なツールです。 確率分布モジュールの他の関数と組み合わせることで、様々な確率計算を効率的に行うことができます。

  • entropy(): エントロピーを計算する
  • cdf(): 累積分布関数を計算する
  • sample(): カテゴリカル分布からランダムサンプリングを行う


例 1: カテゴリカル分布からのサンプリングと対数確率の計算

import torch
from torch.distributions import Categorical

# カテゴリカル分布を定義
probs = torch.tensor([0.2, 0.5, 0.3])
cat = Categorical(probs)

# カテゴリカル分布からサンプリング
sample = cat.sample()
print(sample)  # 出力: tensor(2)

# サンプリングされたカテゴリの対数確率を計算
log_prob = cat.log_prob(sample)
print(log_prob)  # 出力: tensor(-0.4054)

この例では、カテゴリカル分布からランダムサンプリングを行い、そのサンプリングされたカテゴリの対数確率を計算しています。

例 2: バッチデータに対する対数確率の計算

import torch
from torch.distributions import Categorical

# カテゴリカル分布を定義
probs = torch.tensor([[0.2, 0.5, 0.3], [0.1, 0.7, 0.2]])
cat = Categorical(probs)

# バッチデータ
values = torch.tensor([1, 2])

# バッチデータに対する対数確率を計算
log_prob = cat.log_prob(values)
print(log_prob)  # 出力: tensor([-0.4054, -0.4054])

この例では、2つのカテゴリを持つカテゴリカル分布を定義し、バッチデータ [1, 2] に対する対数確率を計算しています。

import torch
from torch.distributions import Categorical

# カテゴリカル分布を定義
probs = torch.tensor([[0.2, 0.5, 0.3], [0.1, 0.7, 0.2]])
cat = Categorical(probs)

# 条件
event = torch.tensor(1)  # カテゴリ 1

# 条件付き対数確率を計算
log_prob = cat.log_prob(values, event)
print(log_prob)  # 出力: tensor([0.0000, -1.3863])


手動計算

カテゴリカル分布の確率質量関数 (PMF) を用いて、対数確率を手動で計算することができます。 PMF は、各カテゴリにおける事象の確率を表す関数です。

import torch

# カテゴリカル分布のパラメータ
probs = torch.tensor([0.2, 0.5, 0.3])

# 計算対象のカテゴリ
value = torch.tensor(2)

# PMF を計算
pmf = probs / probs.sum()

# 対数確率を計算
log_prob = pmf[value].log()

print(log_prob)  # 出力: tensor(-0.4054)

この方法は、柔軟性がありますが、計算量が多くなる場合があります。

カスタム関数

torch.distributions.categorical.Categorical.log_prob() と同等の機能を持つカスタム関数を作成することができます。

import torch

def log_prob_categorical(probs, value):
    # PMF を計算
    pmf = probs / probs.sum()

    # 対数確率を計算
    log_prob = pmf[value].log()
    return log_prob

# カテゴリカル分布のパラメータ
probs = torch.tensor([0.2, 0.5, 0.3])

# 計算対象のカテゴリ
value = torch.tensor(2)

# カスタム関数で対数確率を計算
log_prob = log_prob_categorical(probs, value)
print(log_prob)  # 出力: tensor(-0.4054)

この方法は、コードの再利用性や特定のニーズに合わせたカスタマイズに優れていますが、開発コストがかかります。

PyTorch以外にも、NumPy や JAX などのライブラリで確率分布を扱うためのツールが提供されています。 状況によっては、これらのライブラリのツールの方が使いやすく、効率的な場合もあります。

注意点

代替方法を選択する際には、以下の点に注意する必要があります。

  • 柔軟性: 手動計算やカスタム関数は、より柔軟な計算が可能ですが、torch.distributions.categorical.Categorical.log_prob() はより簡潔で使いやすい場合があります。
  • 計算量: 手動計算やカスタム関数は、torch.distributions.categorical.Categorical.log_prob() よりも計算量が多くなる場合があります。