PyTorch Tensor の高度な比較テクニック:『torch.Tensor.isclose』を超えたアプローチ
torch.Tensor.isclose(other, atol=1e-08, rtol=1e-05, equal_nan=False)
equal_nan
:NaN
を比較対象に含めるかどうか (デフォルト:False
)rtol
: 相対誤差許容値 (デフォルト:1e-05
)atol
: 絶対誤差許容値 (デフォルト:1e-08
)other
: 比較対象となるもう一つのTensor
動作原理
torch.Tensor.isclose
は、以下の条件を満たす場合に True
を返します。
- 2つの
Tensor
の形状が一致する。 - 対応する要素同士が、絶対誤差
atol
または 相対誤差rtol
のいずれか一方以下の差で近似している。- 絶対誤差比較:
abs(input[i] - other[i]) <= atol
- 相対誤差比較:
abs(input[i] - other[i]) / max(abs(input[i]), abs(other[i])) <= rtol
- 絶対誤差比較:
- オプション
equal_nan
がTrue
の場合、NaN
同士も等しいとみなされます。
# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.00001, 2.0001, 3.00000001])
atol = 1e-5
rtol = 1e-4
# 絶対誤差と相対誤差による比較
result1 = x.isclose(y, atol=atol)
result2 = x.isclose(y, rtol=rtol)
print(result1) # tensor([ True True True])
print(result2) # tensor([ True True False])
# NaN を含む場合
x = torch.tensor([1.0, 2.0, float('nan')])
y = torch.tensor([1.0, 2.0, float('nan')])
result3 = x.isclose(y, equal_nan=True)
print(result3) # tensor([ True True True])
torch.Tensor.isclose
は、数値計算における誤差の影響を考慮した比較に役立ちますが、完全に同値かどうかを判定するものではありません。- 許容誤差
atol
とrtol
は、問題のスケールに合わせて適切な値を設定する必要があります。 torch.allclose
関数は、torch.Tensor.isclose
と似ていますが、要素ごとの比較に加え、テンソル全体の形状とサイズも比較します。
絶対誤差と相対誤差による比較
import torch
# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])
# 絶対誤差による比較
atol = 1e-5
result1 = x.isclose(y, atol=atol)
print(f"絶対誤差 {atol} で比較: {result1}")
# 相対誤差による比較
rtol = 1e-4
result2 = x.isclose(y, rtol=rtol)
print(f"相対誤差 {rtol} で比較: {result2}")
出力
絶対誤差 1e-05 で比較: tensor([False False True])
相対誤差 1e-04 で比較: tensor([ True True False])
この例では、atol
を小さく設定すると、より厳密な比較となり、x
と y
は絶対誤差 1e-5
未満でしか一致しなくなります。一方、rtol
を大きく設定すると、x
と y
の値の相対的な違いを許容し、より多くの要素が一致するようになります。
NaN を含む場合
import torch
# テストデータ準備
x = torch.tensor([1.0, 2.0, float('nan')])
y = torch.tensor([1.0, 2.0, float('nan')])
# `equal_nan=True` を指定して比較
result = x.isclose(y, equal_nan=True)
print(f"NaN を含む比較: {result}")
出力
NaN を含む比較: tensor([ True True True])
この例では、equal_nan=True
オプションを指定することで、NaN
同士も等しいとみなされ、すべての要素が一致すると判定されます。
import torch
# テストデータ準備
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[1.0001, 2.0001, 3.00000001], [4, 5, 6]])
# `torch.Tensor.isclose` を使用して比較
atol = 1e-5
result1 = x.isclose(y, atol=atol)
print(f"torch.Tensor.isclose: {result1}")
# `torch.allclose` を使用して比較
rtol = 1e-4
result2 = torch.allclose(x, y, rtol=rtol)
print(f"torch.allclose: {result2}")
torch.Tensor.isclose: tensor([[ True True True]
[ True True True]])
torch.allclose: tensor(True)
絶対誤差と相対誤差の個別比較
- 欠点:
- コードが冗長になる可能性がある。
torch.Tensor.isclose
のような便利な機能がない。
- 利点:
- シンプルでわかりやすい。
- 許容誤差を柔軟に設定できる。
import torch
def is_close(x: torch.Tensor, y: torch.Tensor, atol: float, rtol: float, equal_nan: bool = False) -> torch.Tensor:
"""
2つの Tensor を比較し、絶対誤差と相対誤差の許容範囲内で近似しているかどうかを判定します。
Args:
x (torch.Tensor): 比較対象の Tensor 1。
y (torch.Tensor): 比較対象の Tensor 2。
atol (float): 絶対誤差許容値。
rtol (float): 相対誤差許容値。
equal_nan (bool): NaN を比較対象に含めるかどうか (デフォルト: False)。
Returns:
torch.Tensor: 各要素ごとの比較結果 (True または False)。
"""
abs_diff = torch.abs(x - y)
rel_diff = abs_diff / torch.max(torch.abs(x), torch.abs(y))
result = (abs_diff <= atol) & (rel_diff <= rtol)
if equal_nan:
result = result | (torch.isnan(x) & torch.isnan(y))
return result
# テスト
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])
atol = 1e-5
rtol = 1e-4
result = is_close(x, y, atol, rtol)
print(result)
numpy.allclose を利用する
- 欠点:
- PyTorch 環境でのみ利用可能ではない。
- NumPy をインストールする必要がある。
- 利点:
- NumPy を利用している場合、既存コードに組み込みやすい。
equal_nan
オプションを含むなど、torch.Tensor.isclose
と同様の機能を備えている。
import torch
import numpy as np
# テストデータ準備
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0001, 3.00000001])
# NumPy に変換
x_numpy = x.numpy()
y_numpy = y.numpy()
# NumPy の `allclose` を使用して比較
atol = 1e-5
rtol = 1e-4
result = np.allclose(x_numpy, y_numpy, atol=atol, rtol=rtol, equal_nan=True)
print(result)
- 欠点:
- 開発とテストに時間がかかる。
- コードが複雑になり、保守が難しくなる可能性がある。
- 利点:
- 具体的なニーズに合わせた高度な比較ロジックを実装できる。
- 許容誤差の計算方法や、
NaN
の扱いを自由に定義できる。
import torch
def custom_is_close(x: torch.Tensor, y: torch.Tensor, threshold: float) -> torch.Tensor:
"""
2つの Tensor を比較し、要素ごとの差が指定された閾値以下かどうかを判定します。
Args:
x (torch.Tensor): 比較対象の Tensor 1。
y (torch.Tensor): 比較対象の Tensor 2。
threshold (float): 許容誤差閾値。
Returns:
torch.Tensor: 各要素ごとの比較結果 (True または False)。
"""
diff = torch.abs(