確率分布の柔軟性を高める:PyTorch Probability DistributionsのTransformedDistributionとarg_constraints


理解を深めるためのポイント

  1. TransformedDistribution クラス

    TransformedDistribution クラスは、基底となる分布と変換関数を組み合わせて、新しい分布を定義するためのクラスです。このクラスは、確率分布の柔軟性を高め、複雑な分布を表現するために役立ちます。

  2. パラメータ制約

    確率分布のパラメータには、有効な値の範囲を定義する制約条件がしばしば存在します。これらの制約条件は、分布のサンプリングや確率計算の精度と安定性を保証するために重要です。

  3. 辞書形式

    arg_constraints 属性は、辞書形式でパラメータ制約条件を定義します。辞書のキーはパラメータ名、値は対応する制約条件オブジェクトを表します。

詳細な説明

arg_constraints 辞書には、以下のキーと値が格納されます。


  • 制約条件オブジェクト (例: constraints.positive, constraints.real)
  • キー
    パラメータ名 (例: "loc", "scale")

制約条件オブジェクトは、torch.distributions.constraints モジュールで定義されているクラスです。これらのクラスは、パラメータ値の有効性を検証するためのメソッドを提供します。

以下の例は、Gumbel 分布の arg_constraints 辞書を示しています。

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}

この例では、loc パラメータは実数である必要があり、scale パラメータは正数である必要があります。

arg_constraints 属性は、以下の理由で重要です。

  • 柔軟性
    さまざまな制約条件を定義することで、複雑な分布を表現することができます。
  • ドキュメント
    パラメータの制約条件を明確に定義し、コードの理解しやすさを向上させます。
  • エラー防止
    無効なパラメータ値によるエラーを防ぎ、プログラムの安定性を向上させます。


import torch
from torch.distributions import constraints
from torch.distributions import transformed_distribution
from torch.distributions.gumbel import GumbelDistribution

base_distribution = torch.distributions.Normal(loc=0.0, scale=1.0)
transform = torch.distributions.AffineTransform(loc=1.0, scale=2.0)

transformed_distribution = transformed_distribution.TransformedDistribution(
    base_distribution=base_distribution,
    transform=transform,
)

arg_constraints = transformed_distribution.arg_constraints
print(arg_constraints)  # {'loc': constraints.real, 'scale': constraints.positive}

# サンプルを生成
samples = transformed_distribution.sample((10,))
print(samples)
  1. Normal 分布を基底分布として定義します。
  2. AffineTransform 変換関数を定義します。
  3. TransformedDistribution クラスを使用して、変換された分布を定義します。
  4. arg_constraints 属性を使用して、パラメータ制約条件を印刷します。
  5. サンプリングを行い、結果を印刷します。


カスタムバリデーション関数

arg_constraints のような辞書による制約定義は使わず、カスタムのバリデーション関数を作成する方法です。この関数は、分布のパラメータを受け取り、有効かどうかを判定します。

import torch

def validate_args(loc, scale):
  if not torch.all(scale > 0):
    raise ValueError("scale parameter must be positive")

base_distribution = torch.distributions.Normal(loc=0.0, scale=1.0)
transform = torch.distributions.AffineTransform(loc=1.0, scale=2.0)

transformed_distribution = transformed_distribution.TransformedDistribution(
    base_distribution=base_distribution,
    transform=transform,
    validate_args=validate_args  # カスタムバリデーション関数を渡す
)

# サンプル生成 (バリデーションは内部で行われる)
samples = transformed_distribution.sample((10,))
print(samples)

この方法では、柔軟性が高く、複雑な制約条件にも対応できますが、コードが冗長になる可能性があります。

サブクラス化

TransformedDistribution を継承したサブクラスを作成し、バリデーションロジックをそのサブクラス内に実装する方法です。

import torch
from torch.distributions import transformed_distribution
from torch.distributions.normal import Normal
from torch.distributions.constraints import Positive

class MyTransformedDistribution(transformed_distribution.TransformedDistribution):
  def __init__(self, loc, scale, transform):
    super().__init__(base_distribution=Normal(loc, scale), transform=transform)
    self.register_buffer("scale", torch.clamp(scale, min=0.01))  # 内部でスケールを下限保証

  def _validate_args(self):
    # 何もしない (制約は内部で行われている)

transformed_distribution = MyTransformedDistribution(loc=0.0, scale=1.0, transform=torch.distributions.AffineTransform(loc=1.0, scale=2.0))

# サンプル生成
samples = transformed_distribution.sample((10,))
print(samples)

この方法では、コードの可読性が高まり、再利用性も向上しますが、専用のサブクラスを作成する必要があるため、オーバーヘッドが増加します。

arg_constraints が適切な選択肢

一般的には、arg_constraints 属性が最もシンプルで推奨される方法です。可読性が高く、PyTorch が提供する制約条件オブジェクトを活用できます。