PyTorchで確率分布を理解する:Bernoulli分布とlog_prob()関数


torch.distributions.bernoulli.Bernoulli.log_prob() は、PyTorchの確率分布モジュールにおけるBernoulli分布において、特定の値が生成される確率の対数を計算します。これは、ロジスティック回帰などの確率モデルにおいて、モデルが与えられたデータを出力する確率を評価するために使用されます。

引数

  • distribution: 対象となるBernoulli分布
  • value: 確率を求める値。0 または 1 のどちらかである必要があります。

返り値

  • log_prob: 特定の値が生成される確率の対数

数学的表現

Bernoulli分布の確率密度関数は以下の式で表されます。

p(x) = p^(x) * (1 - p)^(1 - x)

ここで、

  • p: 成功確率 (0 から 1 の間の値)
  • x: 0 または 1 の値

log_prob() 関数は、この確率密度関数の対数をとります。

log_prob(x) = log(p^(x) * (1 - p)^(1 - x))

以下のコードは、Bernoulli分布と log_prob() 関数を使用して、特定の値が生成される確率の対数を出力します。

import torch
from torch.distributions import Bernoulli

# 確率 0.75 の Bernoulli分布を作成
p = 0.75
distribution = Bernoulli(probs=p)

# 特定の値 (0 または 1) を指定
value = 1

# log_prob() 関数を使用して、特定の値が生成される確率の対数を出力
log_prob = distribution.log_prob(value)
print(log_prob)

このコードを実行すると、以下の出力が得られます。

0.8112781244591308

これは、value が 1 である場合の確率が約 0.8113 であることを意味します。

  • log_prob() 関数は、勾配計算をサポートしており、確率モデルの訓練に使用できます。
  • log_prob() 関数は、確率の対数を返すことに注意してください。生の確率値を取得するには、exp() 関数で対数を出力する必要があります。


import torch
from torch.distributions import Bernoulli

# 確率 0.75 の Bernoulli分布を作成
p = 0.75
distribution = Bernoulli(probs=p)

# 特定の値 (0 または 1) のリストを作成
values = [1, 0, 1, 0, 1]

# log_prob() 関数を使用して、各値が生成される確率の対数を出力
log_probs = distribution.log_prob(torch.tensor(values))
print(log_probs)

出力は以下のようになります。

tensor([ 0.8113, -0.2197,  0.8113, -0.2197,  0.8113])


手動計算

p(x) = p ^ x * (1 - p) ^ (1 - x)
  • p: 確率 (0 から 1 の間の値)
  • x: 0 または 1 の値

この式を使用して、特定の値 x における確率の対数 log_prob(x) を手動で計算することができます。

import torch

def log_prob_bernoulli(value, p):
  """
  Bernoulli分布における特定の値の対数確率を計算

  引数:
    value: 確率を求める値 (0 または 1)
    p: 確率 (0 から 1 の間の値)

  返り値:
    log_prob: 特定の値が生成される確率の対数
  """
  return torch.log(p ** value * (1 - p) ** (1 - value))

# 確率 0.75 の Bernoulli分布
p = 0.75

# 特定の値 (0 または 1)
value = 1

# 手動計算による log_prob
log_prob_manual = log_prob_bernoulli(value, p)
print(log_prob_manual)

このコードは、log_prob_bernoulli 関数を使用して、Bernoulli分布における特定の値の対数確率を計算します。

カスタム確率分布クラスの作成

PyTorch では、カスタム確率分布クラスを作成することもできます。これは、より複雑な確率分布をモデル化したい場合や、torch.distributions モジュールに存在しない機能を実装したい場合に役立ちます。

import torch
from torch.distributions import Distribution

class BernoulliLogProb(Distribution):
  """
  カスタム Bernoulli 分布クラス

  引数:
    probs: 確率 (0 から 1 の間の値)

  属性:
    probs: 確率 (0 から 1 の間の値)
  """

  def __init__(self, probs):
    super().__init__(batch_shape=probs.shape)
    self.probs = probs

  def log_prob(self, value):
    """
    特定の値の対数確率を計算

    引数:
      value: 確率を求める値 (0 または 1)

    返り値:
      log_prob: 特定の値が生成される確率の対数
    """
    return torch.log(self.probs ** value * (1 - self.probs) ** (1 - value))

# 確率 0.75 の Bernoulli分布
p = 0.75

# カスタム Bernoulli 分布クラスのインスタンスを作成
distribution = BernoulliLogProb(probs=p)

# 特定の値 (0 または 1)
value = 1

# カスタム分布による log_prob
log_prob_custom = distribution.log_prob(value)
print(log_prob_custom)

このコードは、BernoulliLogProb というカスタム確率分布クラスを作成し、log_prob() メソッドを使用して特定の値の対数確率を計算します。

PyTorch以外にも、確率分布を扱うためのライブラリは数多く存在します。例えば、NumPy や SciPy などは、Bernoulli分布を含む様々な確率分布に対する関数を提供しています。

import numpy as np

# 確率 0.75 の Bernoulli 分布
p = 0.75

# 特定の値 (0 または 1)
value = 1

# NumPy による log_prob
log_prob_numpy = np.log(p ** value * (1 - p) ** (1 - value))
print(log_prob_numpy)

このコードは、NumPyを使用して、Bernoulli分布における特定の値の対数確率を計算します。

どの方法を選択すべきか

最適な方法は、状況によって異なります。

  • 柔軟性: カスタム確率分布クラスは最も柔軟性がありますが、コード量
  • シンプルさ: 手動計算は最もシンプルですが、計算量が多くなります。