コイン投げからサイコロまで! PyTorch Categoricalディストリビューションでカテゴリカル変数を自在に扱う


このチュートリアルでは、torch.distributions.categorical.Categorical の基本的な使い方と、プログラミングにおける具体的な応用例について解説します。

Categorical の基本

Categorical ディストリビューションは、確率パラメータ probs または logits を用いて初期化されます。

  • logits は、各カテゴリの対数オッズを表すベクトルです。
  • probs は、各カテゴリの発生確率を表すベクトルです。要素の合計は 1 でなければなりません。
import torch
import torch.distributions as dist

probs = torch.tensor([0.2, 0.3, 0.5])
categorical = dist.Categorical(probs)

logits = torch.tensor([1.0, 2.0, 3.0])
categorical = dist.Categorical(logits=logits)

サンプリング

Categorical ディストリビューションからサンプリングするには、sample() メソッドを用います。このメソッドは、確率に基づいてランダムなカテゴリを生成します。

sample = categorical.sample()
print(sample)  # tensor(2)

引数 sample_shape を指定することで、複数のサンプルを同時に生成できます。

samples = categorical.sample((5,))
print(samples)  # tensor([2, 0, 1, 2, 1])

確率計算

Categorical ディストリビューションは、各カテゴリにおける事象発生確率を計算することができます。

log_prob = categorical.log_prob(1)
print(log_prob)  # tensor(-1.3862941040922573)

log_prob() メソッドは、指定されたカテゴリにおける事象発生確率の対数オッズを返します。

エントロピー

Categorical ディストリビューションのエントロピーを計算することができます。エントロピーは、分布の不確実性を表す指標です。

entropy = categorical.entropy()
print(entropy)  # tensor(1.0986138056640625)

Categorical ディストリビューションは、様々な場面で役立ちます。以下に、その例をいくつか挙げます。

  • テキスト生成
    次のような単語列生成モデルにおける、次の単語の出現確率を表現することができます。

"私は猫が好きです。猫はかわいいです。"

  • ゲーム
    サイコロの目やトランプのカードを引くような、ランダムな事象をシミュレートすることができます。

  • 推薦システム
    ユーザーの過去の行動に基づいて、次にどの商品を推薦するかを決定することができます。



コイン投げシミュレーション

以下のコードは、コイン投げをシミュレートする例です。

import torch
import torch.distributions as dist

# コインの裏表を表すカテゴリ
heads = 0
tails = 1

# 確率パラメータ (裏が出る確率が 0.7)
probs = torch.tensor([0.3, 0.7])
categorical = dist.Categorical(probs)

# 10回コインを投げる
n_trials = 10
results = categorical.sample((n_trials,))

# 結果のカウント
heads_count = (results == heads).sum().item()
tails_count = (results == tails).sum().item()

print(f"裏が出た回数: {heads_count}")
print(f"表が出た回数: {tails_count}")

このコードを実行すると、以下の出力が得られます。

裏が出た回数: 7
表が出た回数: 3

サイコロの目

以下のコードは、サイコロの目をシミュレートする例です。

import torch
import torch.distributions as dist

# サイコロの目の値を表すカテゴリ
categories = torch.arange(1, 7)

# 各目の出現確率が等しい
probs = torch.ones(len(categories)) / len(categories)
categorical = dist.Categorical(probs)

# 10回サイコロを振る
n_trials = 10
results = categorical.sample((n_trials,))

# 結果のヒストグラム
print(results.hist())
tensor([4, 1, 2, 1, 2, 0, 0, 0, 0, 0])

上記は、10 回サイコロを振った結果のヒストグラムです。各目の出現回数がほぼ均等であることが確認できます。

以下のコードは、カテゴリカル変数からサンプリングする例です。

import torch
import torch.distributions as dist

# カテゴリ
categories = ["cat", "dog", "bird"]

# 各カテゴリの出現確率
probs = torch.tensor([0.25, 0.5, 0.25])
categorical = dist.Categorical(probs)

# 10 回サンプリング
n_samples = 10
samples = categorical.sample((n_samples,))

# 結果の表示
for sample in samples:
    print(categories[sample.item()])
dog
dog
cat
dog
dog
bird
cat
cat
dog
cat

上記は、10 回サンプリングした結果です。各カテゴリが出現する確率が、指定した確率 probs に従っていることが確認できます。



torch.nn.functional.cross_entropy()

torch.nn.functional.cross_entropy() は、クロスエントロピー損失関数を計算する関数です。これは、カテゴリカル分布の確率質量関数 (PMF) と同等の計算を行います。

import torch
import torch.nn.functional as F

# 予測値 (logits)
logits = torch.tensor([1.0, 2.0, 3.0])

# 真のラベル
target = torch.tensor(2)

# クロスエントロピー損失
loss = F.cross_entropy(logits, target)
print(loss)  # tensor(0.4054)

このコードでは、logitstarget を入力として、クロスエントロピー損失を計算しています。logits は、各カテゴリの対数オッズを表すベクトルであり、target は真のラベルを表すスカラ値です。

torch.distributions.OneHotCategorical()

torch.distributions.OneHotCategorical() は、ワンホットベクトル表現のカテゴリカル分布です。これは、各カテゴリを一意に識別するベクトルを生成します。

import torch
import torch.distributions as dist

# 確率パラメータ (probs)
probs = torch.tensor([0.2, 0.3, 0.5])
categorical = dist.OneHotCategorical(probs)

# サンプリング
sample = categorical.sample()
print(sample)  # tensor([0, 0, 1])

このコードでは、probs を入力として、ワンホットベクトル表現のカテゴリカル分布を生成しています。sample() メソッドを用いて、ランダムなカテゴリをサンプリングすることができます。生成されたサンプルは、各カテゴリが 0 または 1 の値を持つベクトルになります。

手動実装

より柔軟な制御が必要な場合は、torch.distributions.categorical.Categorical を手動で実装することもできます。

import torch

def categorical_pmf(logits, value):
    """カテゴリカル分布の確率質量関数 (PMF) を計算する関数。

    Args:
        logits (torch.Tensor): 各カテゴリの対数オッズを表すベクトル。
        value (int): 確率を求めるカテゴリの値。

    Returns:
        torch.Tensor: 指定されたカテゴリの確率。
    """
    max_value = logits.max(dim=1)[0]
    unnormalized_prob = torch.exp(logits - max_value)
    return unnormalized_prob[value] / unnormalized_prob.sum()

# 確率パラメータ (logits)
logits = torch.tensor([1.0, 2.0, 3.0])

# 確率を求めるカテゴリ
value = 2

# PMF の計算
prob = categorical_pmf(logits, value)
print(prob)  # tensor(0.367879441171875)

このコードでは、categorical_pmf() 関数を定義して、カテゴリカル分布の確率質量関数 (PMF) を計算しています。この関数は、logitsvalue を入力として、指定されたカテゴリの確率を返します。

NumPyro や JAX などの他の確率プログラミングライブラリは、Categorical ディストリビューションを含む様々な確率分布を提供しています。これらのライブラリは、PyTorch とは異なる機能や API を提供している場合があります。

torch.distributions.categorical.Categorical は、カテゴリカル分布を扱うための強力なツールです。しかし、状況によっては、他の代替方法の方が適している場合があります。