PyTorch Tensor の高度な比較テクニック:『torch.Tensor.isclose』を超えたアプローチ


torch.Tensor.isclose(other, atol=1e-08, rtol=1e-05, equal_nan=False)
  • equal_nan: NaN を比較対象に含めるかどうか (デフォルト: False)
  • rtol: 相対誤差許容値 (デフォルト: 1e-05)
  • atol: 絶対誤差許容値 (デフォルト: 1e-08)
  • other: 比較対象となるもう一つの Tensor

動作原理

torch.Tensor.isclose は、以下の条件を満たす場合に True を返します。

  1. 2つの Tensor の形状が一致する。
  2. 対応する要素同士が、絶対誤差 atol または 相対誤差 rtol のいずれか一方以下の差で近似している。
    • 絶対誤差比較: abs(input[i] - other[i]) <= atol
    • 相対誤差比較: abs(input[i] - other[i]) / max(abs(input[i]), abs(other[i])) <= rtol
  3. オプション equal_nanTrue の場合、NaN同士も等しいとみなされます。
# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.00001, 2.0001, 3.00000001])
atol = 1e-5
rtol = 1e-4

# 絶対誤差と相対誤差による比較
result1 = x.isclose(y, atol=atol)
result2 = x.isclose(y, rtol=rtol)
print(result1)  # tensor([ True  True  True])
print(result2)  # tensor([ True  True  False])

# NaN を含む場合
x = torch.tensor([1.0, 2.0, float('nan')])
y = torch.tensor([1.0, 2.0, float('nan')])
result3 = x.isclose(y, equal_nan=True)
print(result3)  # tensor([ True  True  True])
  • torch.Tensor.isclose は、数値計算における誤差の影響を考慮した比較に役立ちますが、完全に同値かどうかを判定するものではありません。
  • 許容誤差 atolrtol は、問題のスケールに合わせて適切な値を設定する必要があります。
  • torch.allclose 関数は、torch.Tensor.isclose と似ていますが、要素ごとの比較に加え、テンソル全体の形状とサイズも比較します。


絶対誤差と相対誤差による比較

import torch

# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])

# 絶対誤差による比較
atol = 1e-5
result1 = x.isclose(y, atol=atol)
print(f"絶対誤差 {atol} で比較: {result1}")

# 相対誤差による比較
rtol = 1e-4
result2 = x.isclose(y, rtol=rtol)
print(f"相対誤差 {rtol} で比較: {result2}")

出力

絶対誤差 1e-05 で比較: tensor([False False  True])
相対誤差 1e-04 で比較: tensor([ True  True  False])

この例では、atol を小さく設定すると、より厳密な比較となり、xy は絶対誤差 1e-5 未満でしか一致しなくなります。一方、rtol を大きく設定すると、xy の値の相対的な違いを許容し、より多くの要素が一致するようになります。

NaN を含む場合

import torch

# テストデータ準備
x = torch.tensor([1.0, 2.0, float('nan')])
y = torch.tensor([1.0, 2.0, float('nan')])

# `equal_nan=True` を指定して比較
result = x.isclose(y, equal_nan=True)
print(f"NaN を含む比較: {result}")

出力

NaN を含む比較: tensor([ True  True  True])

この例では、equal_nan=True オプションを指定することで、NaN同士も等しいとみなされ、すべての要素が一致すると判定されます。

import torch

# テストデータ準備
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[1.0001, 2.0001, 3.00000001], [4, 5, 6]])

# `torch.Tensor.isclose` を使用して比較
atol = 1e-5
result1 = x.isclose(y, atol=atol)
print(f"torch.Tensor.isclose: {result1}")

# `torch.allclose` を使用して比較
rtol = 1e-4
result2 = torch.allclose(x, y, rtol=rtol)
print(f"torch.allclose: {result2}")
torch.Tensor.isclose: tensor([[ True  True  True]
                           [ True  True  True]])
torch.allclose: tensor(True)


絶対誤差と相対誤差の個別比較

  • 欠点:
    • コードが冗長になる可能性がある。
    • torch.Tensor.isclose のような便利な機能がない。
  • 利点:
    • シンプルでわかりやすい。
    • 許容誤差を柔軟に設定できる。
import torch

def is_close(x: torch.Tensor, y: torch.Tensor, atol: float, rtol: float, equal_nan: bool = False) -> torch.Tensor:
    """
    2つの Tensor を比較し、絶対誤差と相対誤差の許容範囲内で近似しているかどうかを判定します。

    Args:
        x (torch.Tensor): 比較対象の Tensor 1。
        y (torch.Tensor): 比較対象の Tensor 2。
        atol (float): 絶対誤差許容値。
        rtol (float): 相対誤差許容値。
        equal_nan (bool): NaN を比較対象に含めるかどうか (デフォルト: False)。

    Returns:
        torch.Tensor: 各要素ごとの比較結果 (True または False)。
    """
    abs_diff = torch.abs(x - y)
    rel_diff = abs_diff / torch.max(torch.abs(x), torch.abs(y))
    result = (abs_diff <= atol) & (rel_diff <= rtol)
    if equal_nan:
        result = result | (torch.isnan(x) & torch.isnan(y))
    return result

# テスト
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])
atol = 1e-5
rtol = 1e-4
result = is_close(x, y, atol, rtol)
print(result)

numpy.allclose を利用する

  • 欠点:
    • PyTorch 環境でのみ利用可能ではない。
    • NumPy をインストールする必要がある。
  • 利点:
    • NumPy を利用している場合、既存コードに組み込みやすい。
    • equal_nan オプションを含むなど、torch.Tensor.isclose と同様の機能を備えている。
import torch
import numpy as np

# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])

# NumPy に変換
x_numpy = x.numpy()
y_numpy = y.numpy()

# NumPy の `allclose` を使用して比較
atol = 1e-5
rtol = 1e-4
result = np.allclose(x_numpy, y_numpy, atol=atol, rtol=rtol, equal_nan=True)
print(result)
  • 欠点:
    • 開発とテストに時間がかかる。
    • コードが複雑になり、保守が難しくなる可能性がある。
  • 利点:
    • 具体的なニーズに合わせた高度な比較ロジックを実装できる。
    • 許容誤差の計算方法や、NaN の扱いを自由に定義できる。
import torch

def custom_is_close(x: torch.Tensor, y: torch.Tensor, threshold: float) -> torch.Tensor:
    """
    2つの Tensor を比較し、要素ごとの差が指定された閾値以下かどうかを判定します。

    Args:
        x (torch.Tensor): 比較対象の Tensor 1。
        y (torch.Tensor): 比較対象の Tensor 2。
        threshold (float): 許容誤差閾値。

    Returns:
        torch.Tensor: 各要素ごとの比較結果 (True または False)。
    """
    diff = torch.abs(