【初心者向け】PyTorchで指数分布を扱う:`torch.distributions.exponential.Exponential.log_prob()`を徹底解説


指数分布とは?

指数分布は、ある事象が発生するまでの待ち時間を表す確率分布です。例えば、コールセンターへの電話待ち時間や、機械の故障までの時間間隔などが指数分布に従う場合があります。

指数分布の確率密度関数は以下の式で表されます。

p(x) = rate * exp(-rate * x)  # rate > 0, x >= 0

ここで、

  • x は ランダム変数の値、つまり待ち時間などを表します。
  • rateレートと呼ばれるパラメータであり、事象発生の頻度を表します。

torch.distributions.exponential.Exponential.log_prob() の役割

torch.distributions.exponential.Exponential.log_prob() は、上記で示した確率密度関数の対数値を計算します。

具体的には、入力された値 x が指数分布に従う確率 p(x) の対数 log(p(x)) を返します。

この対数確率は、モデルのパラメータ推定や、異常検知などのタスクで役立ちます。

torch.distributions.exponential.Exponential.log_prob()は以下の通りに使用できます。

import torch
from torch.distributions import Exponential

# レートを指定して指数分布を作成
distribution = Exponential(rate=0.5)

# 確率密度関数の対数値を計算
x = torch.tensor([1.0, 2.0, 3.0])
log_prob = distribution.log_prob(x)
print(log_prob)

このコードを実行すると、以下のような出力が得られます。

tensor([-0.5, -1.0, -1.5])

これは、入力された値 x = [1.0, 2.0, 3.0] が指数分布に従う確率の対数値を表しています。

torch.distributions.exponential.Exponential.log_prob() は、指数分布に従うランダム変数の確率密度関数の対数値を計算する関数です。

  • log_prob() は確率密度関数の対数値を返すため、生の確率値を得るには np.exp() で指数化が必要です。
  • torch.distributions.exponential.Exponential には、sample() メソッドを使ってランダムサンプリングを行うこともできます。


import torch
from torch.distributions import Exponential

# レートを指定して指数分布を作成
distribution = Exponential(rate=0.5)

# 乱数を生成
samples = distribution.sample((5,))
print(samples)

# 各サンプルの確率密度関数の対数値を計算
log_prob = distribution.log_prob(samples)
print(log_prob)
  1. レートが 0.5 の指数分布を作成します。
  2. sample((5,)) メソッドを使って、5個の乱数を生成します。
  3. 生成された乱数 samples の確率密度関数の対数値を計算します。
tensor([1.7839, 0.4054, 1.1091, 2.3026, 0.1271])
tensor([-0.8919, -1.6946, -1.0545, -0.8319, -2.0794])

1つ目の出力は、生成された5個の乱数の値です。

2つ目の出力は、各サンプルの確率密度関数の対数値です。

この例では、生成された乱数の値がそれぞれ異なっていることがわかります。

また、確率密度関数の対数値は、0に近いほど確率が高いことを示しています。

  • 生成された乱数はテンソル形式で返されます。
  • sample() メソッドは、指定された数の乱数を生成します。引数としてサイズを指定することで、生成する乱数の数をコントロールできます。


手動で計算する

指数分布の確率密度関数の対数は、以下の式で計算できます。

log_prob(x) = - rate * x
  • x は ランダム変数の値です。
  • rate は レートと呼ばれるパラメータです。

この式を自分で実装することで、torch.distributions.exponential.Exponential.log_prob() を代替することができます。

他のライブラリを使う

torch.distributions 以外にも、確率分布を扱うライブラリはいくつかあります。

例えば、以下のライブラリでは、指数分布の確率密度関数の対数値を計算する関数を提供しています。

これらのライブラリは、torch.distributions と異なる機能やインターフェースを提供している場合があります。

カスタム分布を作成する

torch.distributions では、カスタム分布を作成することもできます。

これは、既存の分布ではニーズを満たせない場合に役立ちます。

カスタム分布を作成するには、torch.distributions.Distribution クラスを継承する必要があります。

どの代替方法を選ぶべきかは、状況によって異なります。

以下の点を考慮する必要があります。

  • コミュニティ: どのライブラリに活発なコミュニティがあるか。
  • パフォーマンス: どのライブラリがパフォーマンスが良いか。
  • 使いやすさ: どのライブラリが使いやすいか。
  • 必要な機能: 必要な機能がどのライブラリで提供されているか。
  • 特別なニーズがある場合: カスタム分布を作成する。
  • TensorFlow Probability を使用している場合: TensorFlow Probability の指数分布関数を使用する。
  • NumPyro や JAX などのライブラリを既に使用している場合: それらのライブラリの指数分布関数を使用する。
  • シンプルで高速な方法が必要な場合: 手動で計算するか、torch.distributions.exponential.Exponential.log_prob() をそのまま使用する。