【初心者向け】AffineTransformで確率分布をカンタン変換!サンプルコード付き


torch.distributions.transforms.AffineTransform は、PyTorch の "Probability Distributions" モジュールで提供されるクラスで、確率分布を変換するためのアフィン変換を実装しています。この変換は、入力データに対して線形変換と平行移動を適用し、出力分布を生成します。

主な機能

  • 変換と逆変換を実行する
  • 変換された分布のサンプリングを行う
  • 変換された分布の確率密度関数と累積分布関数を計算する
  • 入力データに対する線形変換と平行移動を適用する

使用方法

AffineTransform クラスは、以下の引数を受け取ります。

  • event_dim: 変換を適用する次元
  • scale: 線形変換行列
  • loc: 平行移動ベクトル

これらの引数を使用して、AffineTransform オブジェクトを作成できます。

import torch
from torch.distributions import transforms

loc = torch.tensor([1.0, 2.0])
scale = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
event_dim=0

affine_transform = transforms.AffineTransform(loc=loc, scale=scale, event_dim=event_dim)

AffineTransform オブジェクトを使用して、確率分布を変換できます。

base_distribution = torch.distributions.Normal(loc=0.0, scale=1.0)
transformed_distribution = affine_transform(base_distribution)

変換された分布の確率密度関数と累積分布関数を計算できます。

x = torch.tensor([-1.0, 0.0, 1.0])
pdf = transformed_distribution.pdf(x)
cdf = transformed_distribution.cdf(x)

変換された分布からサンプルを生成できます。

n_samples = 1000
samples = transformed_distribution.sample(n_samples)

変換と逆変換を実行できます。

y = affine_transform(x)
x_inv = affine_transform.inv(y)

応用例

AffineTransform は、以下のようなさまざまな場面で使用できます。

  • 確率分布の合成: 複数の確率分布を組み合わせる
  • データのシフト: データを一定量だけシフトする
  • データのスケーリング: データのスケールを変更する
  • データの標準化: データを平均 0、標準偏差 1 に変換する

AffineTransform は、確率分布を変換するための強力なツールです。このクラスを理解することで、さまざまなデータ変換タスクを効率的に実行することができます。



import torch
from torch.distributions import Normal, AffineTransform

# 正規分布を定義
base_distribution = Normal(loc=0.0, scale=1.0)

# アフィン変換を定義
loc = torch.tensor([1.0, 2.0])
scale = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
event_dim = 0

affine_transform = AffineTransform(loc=loc, scale=scale, event_dim=event_dim)

# 変換された分布を生成
transformed_distribution = affine_transform(base_distribution)

# サンプルを生成
n_samples = 1000
samples = transformed_distribution.sample(n_samples)

# サンプルを可視化
import matplotlib.pyplot as plt

plt.hist(samples.numpy(), bins=100)
plt.xlabel('x')
plt.ylabel('PDF')
plt.title('Transformed Distribution')
plt.show()

このコードでは、まず Normal クラスを使用して、平均 0、標準偏差 1 の正規分布を定義します。次に、AffineTransform クラスを使用して、平行移動ベクトル loc と線形変換行列 scale を指定してアフィン変換を定義します。event_dim 引数は、変換を適用する次元を指定します。

その後、AffineTransform オブジェクトを使用して、元の分布を変換し、新しい分布を生成します。生成された分布からサンプルを生成し、ヒストグラムを使用して可視化します。

この例は、AffineTransform を使って分布をどのように変換できるかを示す基本的な例です。実際の使用例では、より複雑な変換や、さまざまな種類の分布に対する変換が必要になる場合があります。

以下のコードは、AffineTransform を使って以下の操作を実行する方法を示しています。

  • データのシフト
  • データのスケーリング
  • データの標準化

これらの例は、AffineTransform の柔軟性と、さまざまなデータ変換タスクに適用できることを示しています。



手動の計算

最も基本的な代替手段は、アフィン変換を明示的に計算することです。これは、以下の式で行うことができます。

y = scale @ x + loc

この方法は、単純で理解しやすいという利点がありますが、計算量が多くなる場合があります。また、自動微分をサポートしていないため、勾配計算が必要な場合は不向きです。

torch.nn モジュールの機能

torch.nn モジュールには、線形変換と平行移動を実行するためのさまざまな機能が用意されています。これらの機能を使用すると、AffineTransform クラスよりも簡潔で効率的にコードを記述することができます。

import torch
import torch.nn as nn

# 線形変換層を定義
linear_layer = nn.Linear(in_features=2, out_features=2)

# 平行移動ベクトルを定義
loc = torch.tensor([1.0, 2.0])

# 変換を適用
x = torch.tensor([[-1.0, 0.0], [0.0, 1.0]])
y = linear_layer(x) + loc

この方法は、AffineTransform クラスよりも高速でメモリ効率が良いという利点があります。また、自動微分をサポートしているため、勾配計算が必要な場合にも適しています。

カスタム変換

独自の変換を実装することもできます。これは、複雑な変換や、AffineTransform クラスでサポートされていない機能が必要な場合に役立ちます。

import torch

class MyTransform(torch.distributions.transforms.Transform):
    def __init__(self, c):
        super().__init__()
        self.c = c

    def _call(self, x):
        return x**2 + self.c

    def _inverse(self, y):
        return torch.sqrt(y - self.c)

base_distribution = torch.distributions.Normal(loc=0.0, scale=1.0)
transformed_distribution = MyTransform(c=1.0)(base_distribution)

この方法は、柔軟性が高いという利点がありますが、実装が複雑になる場合があります。また、変換の確率密度関数と累積分布関数、およびその導関数を導出する必要があります。

適切な代替手段は、状況によって異なります。以下の要素を考慮する必要があります。

  • 自動微分
    勾配計算が必要な場合は、torch.nn モジュールの機能が適切です。手動の計算は自動微分をサポートしておらず、カスタム変換は実装次第です。
  • メモリ効率
    メモリ効率が重要な場合は、torch.nn モジュールの機能が適切です。手動の計算とカスタム変換は、状況によって異なります。
  • 計算量
    計算量が多い場合は、torch.nn モジュールの機能が適切です。手動の計算は最も計算量が多くなり、カスタム変換は状況によって異なります。
  • 変換の複雑さ
    変換が単純な場合は、手動の計算または torch.nn モジュールの機能が適切です。複雑な変換の場合は、カスタム変換が必要になる場合があります。