コインの裏表をシミュレート?PyTorchで幾何分布を扱う「torch.distributions.geometric.Geometric」


torch.distributions.geometric.Geometric は、PyTorch の Probability Distributions モジュールで提供される幾何分布を表すクラスです。幾何分布は、ベルヌーイ試行において、初めて成功するまでに必要な試行回数(失敗回数 + 1)をモデル化する確率分布です。

このチュートリアルでは、以下の内容を解説します

  • 例:コインの裏表を繰り返す試行
  • エンタロピーと情報量
  • 確率密度関数と累積分布関数
  • サンプルの生成
  • パラメータとサポート

torch.distributions.geometric.Geometric は、以下の引数を受け取ります。

  • logits: 成功確率のロジット (対数オッズ) を表すテンソルまたはスカラー
  • probs: 成功確率 (0 ~ 1) を表すテンソルまたはスカラー

パラメータとサポート

  • support: 非負整数 (0, 1, 2, ...)
  • logits: 実数値である必要があります。
  • probs: 0 から 1 までの範囲にある必要があります。

サンプルの生成

sample() メソッドを使用して、幾何分布からのサンプルを生成できます。

import torch
from torch.distributions import Geometric

probs = torch.tensor([0.3])
g = Geometric(probs)
sample = g.sample()
print(sample)

出力

tensor([2])

確率密度関数と累積分布関数

log_prob() メソッドを使用して、特定の事象の確率密度関数を計算できます。

log_prob = g.log_prob(2)
print(log_prob)

出力

tensor(-1.3862)

cdf() メソッドを使用して、特定の事象以下の累積分布関数を計算できます。

cdf = g.cdf(2)
print(cdf)

出力

tensor(0.18)

エンタロピーと情報量

entropy() メソッドを使用して、分布のエンタロピーを計算できます。

entropy = g.entropy()
print(entropy)

出力

tensor(1.0986)

information_content() メソッドを使用して、分布の情報量を計算できます。

information_content = g.information_content()
print(information_content)

出力

tensor(1.0986)

例:コインの裏表を繰り返す試行

コインの裏表を繰り返す試行において、初めて表が出るまでに必要な裏面の数をモデル化してみましょう。

import torch
from torch.distributions import Geometric

probs = torch.tensor([0.5])
g = Geometric(probs)

# 10 回試行
n_trials = 10
samples = g.sample(n_trials)

# 結果のヒストグラムを表示
import matplotlib.pyplot as plt

plt.hist(samples.numpy())
plt.xlabel("Number of failures before success")
plt.ylabel("Count")
plt.title("Histogram of coin toss example")
plt.show()

この例では、コインの裏表を 10 回繰り返す試行をシミュレートし、初めて表が出るまでの裏面の数をヒストグラムで表示しています。



import torch
from torch.distributions import Geometric

# パラメータの設定
probs = torch.tensor([0.3])

# Geometric分布の生成
g = Geometric(probs)

# サンプルの生成
sample = g.sample()
print(f"サンプリング結果: {sample}")

# 確率密度関数の計算
log_prob = g.log_prob(2)
print(f"確率密度関数 (x = 2): {log_prob}")

# 累積分布関数の計算
cdf = g.cdf(2)
print(f"累積分布関数 (x <= 2): {cdf}")

# エンタロピーの計算
entropy = g.entropy()
print(f"エンタロピー: {entropy}")

# 情報量の計算
information_content = g.information_content()
print(f"情報量: {information_content}")
  • information_content() メソッドを使用して、分布の情報量を計算します。
  • entropy() メソッドを使用して、分布のエンタロピーを計算します。
  • cdf() メソッドを使用して、x <= 2 における累積分布関数を計算します。
  • log_prob() メソッドを使用して、x = 2 における確率密度関数を計算します。
  • sample() メソッドを使用して、分布からサンプルを生成します。
  • Geometric クラスを使用して、このパラメータに基づいて幾何分布を生成します。
  • パラメータ probs として、成功確率 0.3 を設定します。


手動実装

最も基本的な代替方法は、幾何分布の確率密度関数と累積分布関数を自分で実装することです。これは、シンプルなケースであれば有効ですが、複雑な場合やコードの再利用性を高めたい場合は、非効率的になる可能性があります。

import torch

def geometric_pmf(k, p):
    return (1 - p) ** k * p

def geometric_cdf(k, p):
    return 1 - (1 - p) ** (k + 1)

# 例:確率密度関数と累積分布関数の計算

probs = torch.tensor([0.3])
k = torch.tensor(2)

pmf = geometric_pmf(k, probs)
print(f"確率密度関数 (x = 2): {pmf}")

cdf = geometric_cdf(k, probs)
print(f"累積分布関数 (x <= 2): {cdf}")

利点

  • 計算効率を最大限に高めることができます。
  • コードを完全に制御できます。

欠点

  • バグが発生しやすい可能性があります。
  • 複雑な場合やコードの再利用性を高めたい場合は、非効率的になります。

torch.distributions.Bernoulli を用いたカスタム分布

torch.distributions.Bernoulli を用いて、カスタム分布を作成することもできます。この方法は、torch.distributions.geometric.Geometric よりも柔軟性が高く、複雑なモデル化にも対応できます。

import torch
from torch.distributions import Bernoulli

def geometric_custom(probs):
    g = Bernoulli(probs)
    def sample():
        success = False
        count = 0
        while not success:
            trial = g.sample()
            success = trial == 1
            count += 1
        return count
    return sample

# 例:サンプリング

probs = torch.tensor([0.3])
custom_geometric = geometric_custom(probs)
sample = custom_geometric()
print(f"サンプリング結果: {sample}")

利点

  • torch.distributions モジュールの機能を活用できます。
  • 柔軟性が高く、複雑なモデル化にも対応できます。

欠点

  • torch.distributions.geometric.Geometric よりも計算効率が低くなる場合があります。
  • コードが冗長になる可能性があります。

サードパーティライブラリ

scipystatsmodels などのサードパーティライブラリを使用することもできます。これらのライブラリは、torch.distributions.geometric.Geometricと同等の機能を提供しているだけでなく、追加の統計関数を提供している場合があります。

import scipy.stats as stats

# 例:確率密度関数と累積分布関数の計算

probs = 0.3
k = 2

pmf = stats.geom.pmf(k, probs)
print(f"確率密度関数 (x = 2): {pmf}")

cdf = stats.geom.cdf(k, probs)
print(f"累積分布関数 (x <= 2): {cdf}")

利点

  • コードが簡潔になります。
  • 統計関数の豊富なライブラリを利用できます。

欠点

  • 追加のライブラリをインストールする必要があります。
  • PyTorch との統合が制限されている場合があります。

最適な代替方法の選択

最適な代替方法は、具体的なニーズと要件によって異なります。

  • 豊富な統計関数と簡潔なコードを必要とする場合は、サードパーティライブラリが適しています。
  • 柔軟性と汎用性が必要な場合は、torch.distributions.Bernoulli を用いたカスタム分布が適しています。
  • シンプルで効率的な方法が必要な場合は、手動実装が適しています。