確率分布ライブラリで実現する低ランク多変量正規分布からの乱数生成:`torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample()` のしくみ


  • validate_args
    入力パラメータの検証を行うかどうかを指定するブーリアン値です。デフォルトは False です。
  • sample_shape
    生成する乱数のバッチサイズを表すタプルです。
  • cov_diag
    共分散行列の対角成分となるテンソルです。形状は (d,) であり、d は変数の数と同じです。
  • cov_factor
    低ランク共分散行列の因子となるテンソルです。形状は (d, r) であり、ここで d は変数の数、r はランクを表します。

この関数は、以下の処理を実行します。

  1. 低ランク共分散行列 Σ を計算します。
Sigma = torch.matmul(cov_factor, cov_factor.t()) + torch.diag(cov_diag)
  1. Cholesky分解を用いて、Σ の平方根 L を計算します。
L = torch.linalg.cholesky(Sigma)
  1. 標準正規乱数を生成します。
z = _standard_normal(sample_shape + L.shape)
  1. 生成した乱数を L で変換します。
x = torch.matmul(L, z)
  1. 変換された乱数を返します。

この関数は、以下の点に注意する必要があります。

  • validate_argsTrue の場合、入力パラメータの検証が行われます。
  • sample_shape はタプルである必要があります。
  • cov_factorcov_diag の形状が正しいことを確認する必要があります。

以下の例は、torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample() 関数を使用して、5 次元の多変量正規分布から 10 個の乱数を生成する方法を示します。

import torch
from torch.distributions import LowRankMultivariateNormal

# パラメータの設定
cov_factor = torch.randn(5, 2)
cov_diag = torch.rand(5)
sample_shape = (10,)

# 低ランク多変量正規分布の作成
mvnormal = LowRankMultivariateNormal(cov_factor, cov_diag)

# 乱数の生成
samples = mvnormal.rsample(sample_shape)

print(samples)

このコードを実行すると、以下のような出力が得られます。

tensor([[ 0.2360,  0.4184,  0.1321, -0.0894,  0.0039],
        [ 0.7492, -1.0118,  0.7053,  0.2052,  0.0145],
        [-0.1587, -0.3229, -0.5249, -0.2930,  0.5050],
        [ 0.0951, -0.0542,  0.2281,  0.4918, -0.2493],
        [-0.5043, -0.1575,  0.1222,  0.2300, -0.1195],
        [ 0.4154,  0.5811, -0.1560, -0.0463, -0.2498],
        [ 0.2079, -0.0478,  0.6974,  0.2070, -0.2007],
        [-0.3049, -0.0752, -0.2854, -0.0584,  0.1472],
        [ 0.5407,  0.0299,  0.0974,  0.4915, -0.


import torch
from torch.distributions import LowRankMultivariateNormal

# Define the parameters of the low-rank multivariate normal distribution
cov_factor = torch.randn(5, 2)  # Covariance factor matrix
cov_diag = torch.rand(5)  # Diagonal elements of the covariance matrix
sample_shape = (10,)  # Batch size of the samples

# Create the low-rank multivariate normal distribution
mvnormal = LowRankMultivariateNormal(cov_factor, cov_diag)

# Generate random samples
samples = mvnormal.rsample(sample_shape)

# Print the samples
print(samples)

This code will generate 10 random samples from a 5-dimensional multivariate normal distribution with a low-rank covariance matrix. The covariance matrix is defined by the cov_factor and cov_diag parameters. The sample_shape parameter specifies the batch size of the samples.

Here is a breakdown of the code:

    • torch: The main PyTorch library.
    • torch.distributions: The PyTorch distributions module, which contains the LowRankMultivariateNormal class.
  1. Define the parameters of the low-rank multivariate normal distribution

    • cov_factor: A tensor of shape (d, r) where d is the number of variables and r is the rank of the covariance matrix. The cov_factor matrix defines the low-rank part of the covariance matrix.
    • cov_diag: A tensor of shape (d,) where d is the number of variables. The cov_diag tensor defines the diagonal elements of the covariance matrix.
    • sample_shape: A tuple of integers specifying the batch size of the samples.
  2. Create the low-rank multivariate normal distribution

    • An instance of the LowRankMultivariateNormal class is created using the cov_factor, cov_diag, and sample_shape parameters.
  3. Generate random samples

    • The rsample() method of the mvnormal object is called to generate random samples from the low-rank multivariate normal distribution. The samples are returned as a tensor of shape (sample_shape, d).
  4. Print the samples

    • The samples tensor is printed to the console.


torch.mvnormal() 関数を使用する

torch.mvnormal() 関数は、標準の多変量正規分布から乱数を生成するために使用できます。この関数は、以下のパラメータを受け取ります。

  • sample_shape
    生成する乱数のバッチサイズを表すタプルです。
  • cov_matrix
    共分散行列となるテンソルです。形状は (d, d) であり、ここで d は変数の数と同じです。
  • mean
    平均ベクトルとなるテンソルです。形状は (d,) であり、ここで d は変数の数です。

torch.mvnormal() 関数は、以下の処理を実行します。

  1. 共分散行列の平方根 L を計算します。
L = torch.linalg.cholesky(cov_matrix)
  1. 標準正規乱数を生成します。
z = _standard_normal(sample_shape + L.shape)
  1. 生成した乱数を L で変換します。
x = torch.matmul(L, z)
  1. 変換された乱数を返します。

torch.mvnormal() 関数は、以下の点に注意する必要があります。

  • sample_shape はタプルである必要があります。
  • 共分散行列が正定値であることを確認する必要があります。

torch.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample() の代替方法として、torch.mvnormal() 関数を使用する方法は、共分散行列が対角行列である場合に有効です。

以下の例は、torch.mvnormal() 関数を使用して、5 次元の多変量正規分布から 10 個の乱数を生成する方法を示します。

import torch

# パラメータの設定
mean = torch.zeros(5)
cov_matrix = torch.eye(5)
sample_shape = (10,)

# 乱数の生成
samples = torch.mvnormal(mean, cov_matrix, sample_shape)

print(samples)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0476,  0.0068,  0.0628, -0.0798,  0.0299],
        [ 0.0436, -0.0993, -0.0429,  0.0117,  0.0624],
        [-0.1227, -0.0129,  0.1320,  0.0167,  0.0023],
        [-0.0310, -0.0498, -0.0103, -0.0841, -0.0448],
        [ 0.1067,  0.0187,  0.0164, -0.0201, -0.0648],
        [ 0.0285,  0.0845,  0.0312,  0.0140,  0.0394],
        [ 0.0317, -0.0760, -0.0031,  0.0102,  0.0321],
        [ 0.0460,  0.0290, -0.0153, -0.0513,  0.0745],
        [-0.0105,  0.0446,  0.0031,  0.0318, -0.0318]])

手動でサンプルを生成する

torch.lowrank_multivariate_normal.LowRankMultivariateNormal.rsample() の代替方法として、手動でサンプルを生成することもできます。この方法は、以下の手順で行うことができます。

  1. 低ランク共分散行列 Σ を計算します。