確率分布ライブラリで実現する低ランク多変量正規分布からの乱数生成:`torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample()` のしくみ
- validate_args
入力パラメータの検証を行うかどうかを指定するブーリアン値です。デフォルトはFalse
です。 - sample_shape
生成する乱数のバッチサイズを表すタプルです。 - cov_diag
共分散行列の対角成分となるテンソルです。形状は(d,)
であり、d
は変数の数と同じです。 - cov_factor
低ランク共分散行列の因子となるテンソルです。形状は(d, r)
であり、ここでd
は変数の数、r
はランクを表します。
この関数は、以下の処理を実行します。
- 低ランク共分散行列
Σ
を計算します。
Sigma = torch.matmul(cov_factor, cov_factor.t()) + torch.diag(cov_diag)
- Cholesky分解を用いて、
Σ
の平方根L
を計算します。
L = torch.linalg.cholesky(Sigma)
- 標準正規乱数を生成します。
z = _standard_normal(sample_shape + L.shape)
- 生成した乱数を
L
で変換します。
x = torch.matmul(L, z)
- 変換された乱数を返します。
この関数は、以下の点に注意する必要があります。
validate_args
がTrue
の場合、入力パラメータの検証が行われます。sample_shape
はタプルである必要があります。cov_factor
とcov_diag
の形状が正しいことを確認する必要があります。
以下の例は、torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample()
関数を使用して、5 次元の多変量正規分布から 10 個の乱数を生成する方法を示します。
import torch
from torch.distributions import LowRankMultivariateNormal
# パラメータの設定
cov_factor = torch.randn(5, 2)
cov_diag = torch.rand(5)
sample_shape = (10,)
# 低ランク多変量正規分布の作成
mvnormal = LowRankMultivariateNormal(cov_factor, cov_diag)
# 乱数の生成
samples = mvnormal.rsample(sample_shape)
print(samples)
このコードを実行すると、以下のような出力が得られます。
tensor([[ 0.2360, 0.4184, 0.1321, -0.0894, 0.0039],
[ 0.7492, -1.0118, 0.7053, 0.2052, 0.0145],
[-0.1587, -0.3229, -0.5249, -0.2930, 0.5050],
[ 0.0951, -0.0542, 0.2281, 0.4918, -0.2493],
[-0.5043, -0.1575, 0.1222, 0.2300, -0.1195],
[ 0.4154, 0.5811, -0.1560, -0.0463, -0.2498],
[ 0.2079, -0.0478, 0.6974, 0.2070, -0.2007],
[-0.3049, -0.0752, -0.2854, -0.0584, 0.1472],
[ 0.5407, 0.0299, 0.0974, 0.4915, -0.
import torch
from torch.distributions import LowRankMultivariateNormal
# Define the parameters of the low-rank multivariate normal distribution
cov_factor = torch.randn(5, 2) # Covariance factor matrix
cov_diag = torch.rand(5) # Diagonal elements of the covariance matrix
sample_shape = (10,) # Batch size of the samples
# Create the low-rank multivariate normal distribution
mvnormal = LowRankMultivariateNormal(cov_factor, cov_diag)
# Generate random samples
samples = mvnormal.rsample(sample_shape)
# Print the samples
print(samples)
This code will generate 10 random samples from a 5-dimensional multivariate normal distribution with a low-rank covariance matrix. The covariance matrix is defined by the cov_factor
and cov_diag
parameters. The sample_shape
parameter specifies the batch size of the samples.
Here is a breakdown of the code:
torch
: The main PyTorch library.torch.distributions
: The PyTorch distributions module, which contains theLowRankMultivariateNormal
class.
Define the parameters of the low-rank multivariate normal distribution
cov_factor
: A tensor of shape(d, r)
whered
is the number of variables andr
is the rank of the covariance matrix. Thecov_factor
matrix defines the low-rank part of the covariance matrix.cov_diag
: A tensor of shape(d,)
whered
is the number of variables. Thecov_diag
tensor defines the diagonal elements of the covariance matrix.sample_shape
: A tuple of integers specifying the batch size of the samples.
Create the low-rank multivariate normal distribution
- An instance of the
LowRankMultivariateNormal
class is created using thecov_factor
,cov_diag
, andsample_shape
parameters.
- An instance of the
Generate random samples
- The
rsample()
method of themvnormal
object is called to generate random samples from the low-rank multivariate normal distribution. The samples are returned as a tensor of shape(sample_shape, d)
.
- The
Print the samples
- The
samples
tensor is printed to the console.
- The
torch.mvnormal() 関数を使用する
torch.mvnormal()
関数は、標準の多変量正規分布から乱数を生成するために使用できます。この関数は、以下のパラメータを受け取ります。
- sample_shape
生成する乱数のバッチサイズを表すタプルです。 - cov_matrix
共分散行列となるテンソルです。形状は(d, d)
であり、ここでd
は変数の数と同じです。 - mean
平均ベクトルとなるテンソルです。形状は(d,)
であり、ここでd
は変数の数です。
torch.mvnormal()
関数は、以下の処理を実行します。
- 共分散行列の平方根
L
を計算します。
L = torch.linalg.cholesky(cov_matrix)
- 標準正規乱数を生成します。
z = _standard_normal(sample_shape + L.shape)
- 生成した乱数を
L
で変換します。
x = torch.matmul(L, z)
- 変換された乱数を返します。
torch.mvnormal()
関数は、以下の点に注意する必要があります。
sample_shape
はタプルである必要があります。- 共分散行列が正定値であることを確認する必要があります。
torch.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample()
の代替方法として、torch.mvnormal()
関数を使用する方法は、共分散行列が対角行列である場合に有効です。
例
以下の例は、torch.mvnormal()
関数を使用して、5 次元の多変量正規分布から 10 個の乱数を生成する方法を示します。
import torch
# パラメータの設定
mean = torch.zeros(5)
cov_matrix = torch.eye(5)
sample_shape = (10,)
# 乱数の生成
samples = torch.mvnormal(mean, cov_matrix, sample_shape)
print(samples)
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0476, 0.0068, 0.0628, -0.0798, 0.0299],
[ 0.0436, -0.0993, -0.0429, 0.0117, 0.0624],
[-0.1227, -0.0129, 0.1320, 0.0167, 0.0023],
[-0.0310, -0.0498, -0.0103, -0.0841, -0.0448],
[ 0.1067, 0.0187, 0.0164, -0.0201, -0.0648],
[ 0.0285, 0.0845, 0.0312, 0.0140, 0.0394],
[ 0.0317, -0.0760, -0.0031, 0.0102, 0.0321],
[ 0.0460, 0.0290, -0.0153, -0.0513, 0.0745],
[-0.0105, 0.0446, 0.0031, 0.0318, -0.0318]])
手動でサンプルを生成する
torch.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample()
の代替方法として、手動でサンプルを生成することもできます。この方法は、以下の手順で行うことができます。
- 低ランク共分散行列
Σ
を計算します。