【初心者向け】PyTorchの`Tensor.allclose`関数: 詳細解説とサンプルコード


関数詳細

torch.allclose(input, other, atol=1e-8, rtol=1e-5, equal_nan=False)
  • equal_nan: True の場合、NaN は等しいとみなされます (デフォルト: False)。
  • rtol: 相対許容誤差 (デフォルト: 1e-5)。2つの値の絶対差がこの値 * max(abs(input), abs(other)) より小さい場合、それらは等しいとみなされます。
  • atol: 絶対許容誤差 (デフォルト: 1e-8)。2つの値の差がこの値より小さい場合、それらは等しいとみなされます。
  • other: 比較する2番目の Tensor
  • input: 比較する最初の Tensor

戻り値

  • Tensor: 2つの Tensor が要素ごとにほぼ等しいかどうかを示すブール型 Tensor
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])

# 絶対許容誤差を 1e-7 に設定して比較
print(torch.allclose(a, b, atol=1e-7))
# True

# 相対許容誤差を 1e-4 に設定して比較
print(torch.allclose(a, b, rtol=1e-4))
# True

# NaN を等しいとみなして比較
c = torch.tensor([1.0, 2.0, float('nan')])
print(torch.allclose(a, c, equal_nan=True))
# True
  • equal_nan パラメータは、NaN の扱いを制御するために使用されます。
  • 比較精度を調整するには、atolrtol パラメータを適切な値に設定する必要があります。
  • torch.allclose 関数は、2つの Tensor の形状とデータ型が一致していることを前提としています。


絶対許容誤差と相対許容誤差

この例では、絶対許容誤差と相対許容誤差を使用して、2つの Tensor がどれだけ近いのかを調べます。

import torch

a = torch.tensor([1.0, 2.0, 1e11])
b = a + torch.tensor([1e-8, 1e-7, 1e-4])

# 絶対許容誤差を 1e-9 に設定
print(torch.allclose(a, b, atol=1e-9))
# False

# 相対許容誤差を 1e-3 に設定
print(torch.allclose(a, b, rtol=1e-3))
# True

この例では、atol を小さくすると精度が上がり、rtol を小さくすると許容される誤差が小さくなります。

NaN の扱い

この例では、equal_nan パラメータを使用して、NaN を比較する方法を示します。

import torch

a = torch.tensor([1.0, 2.0, float('nan')])
b = torch.tensor([1.0, 2.0, float('nan')])

# NaN を等しいとみなさない場合
print(torch.allclose(a, b))
# False

# NaN を等しいとみなす場合
print(torch.allclose(a, b, equal_nan=True))
# True

この例では、equal_nanTrue に設定すると、NaN は等しいとみなされます。

この例では、torch.allclose 関数が2つの Tensor の形状とデータ型が一致していることを前提としていることを示します。

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)

# データ型が異なる場合
print(torch.allclose(a, b))
# RuntimeError: Can only compare Tensors of same dtype. Got: float32, float64

# 形状が異なる場合
b = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.allclose(a, b))
# RuntimeError: Can only compare Tensors of same size. Got: (3,), (4,)

この例では、2つの Tensor の形状とデータ型が一致していない場合、エラーが発生します。



逐次比較

最も基本的な方法は、ループを使用して2つの Tensor の要素を個別に比較することです。

import torch

def is_close(a, b, atol=1e-8, rtol=1e-5):
    diff = abs(a - b)
    epsilon = max(abs(a), abs(b)) * rtol
    return (diff < atol) | (diff < epsilon)

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])

allclose = True
for i in range(len(a)):
    if not is_close(a[i], b[i]):
        allclose = False
        break

print(allclose)
# True

長所

  • メモリ使用量が少ない
  • シンプルで分かりやすい

短所

  • 複雑な比較には向かない
  • 遅い

torch.norm 関数を使用する

torch.norm 関数を使用して、2つの Tensor の間の距離を計算し、許容範囲内かどうかを判断することができます。

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])

threshold = atol + rtol * max(torch.norm(a), torch.norm(b))
distance = torch.norm(a - b)

allclose = distance < threshold
print(allclose)
# True

長所

  • 逐次比較よりも高速

短所

  • 許容範囲を計算するのが難しい
  • torch.norm 関数の計算コストが高い

カスタム比較関数を使用する

特定のニーズに合わせて、独自の比較関数を定義することができます。

import torch

def custom_compare(a, b):
    # 独自の比較ロジックを実装
    # ...
    return result

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])

allclose = all(custom_compare(a[i], b[i]) for i in range(len(a)))
print(allclose)
# True

長所

  • 複雑な比較に対応できる
  • 柔軟性が高い
  • 時間がかかる場合がある
  • 実装が難しい