【初心者向け】PyTorchの`Tensor.allclose`関数: 詳細解説とサンプルコード
関数詳細
torch.allclose(input, other, atol=1e-8, rtol=1e-5, equal_nan=False)
equal_nan
:True
の場合、NaN
は等しいとみなされます (デフォルト:False
)。rtol
: 相対許容誤差 (デフォルト:1e-5
)。2つの値の絶対差がこの値 *max(abs(input), abs(other))
より小さい場合、それらは等しいとみなされます。atol
: 絶対許容誤差 (デフォルト:1e-8
)。2つの値の差がこの値より小さい場合、それらは等しいとみなされます。other
: 比較する2番目のTensor
input
: 比較する最初のTensor
戻り値
Tensor
: 2つのTensor
が要素ごとにほぼ等しいかどうかを示すブール型Tensor
。
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])
# 絶対許容誤差を 1e-7 に設定して比較
print(torch.allclose(a, b, atol=1e-7))
# True
# 相対許容誤差を 1e-4 に設定して比較
print(torch.allclose(a, b, rtol=1e-4))
# True
# NaN を等しいとみなして比較
c = torch.tensor([1.0, 2.0, float('nan')])
print(torch.allclose(a, c, equal_nan=True))
# True
equal_nan
パラメータは、NaN
の扱いを制御するために使用されます。- 比較精度を調整するには、
atol
とrtol
パラメータを適切な値に設定する必要があります。 torch.allclose
関数は、2つのTensor
の形状とデータ型が一致していることを前提としています。
絶対許容誤差と相対許容誤差
この例では、絶対許容誤差と相対許容誤差を使用して、2つの Tensor
がどれだけ近いのかを調べます。
import torch
a = torch.tensor([1.0, 2.0, 1e11])
b = a + torch.tensor([1e-8, 1e-7, 1e-4])
# 絶対許容誤差を 1e-9 に設定
print(torch.allclose(a, b, atol=1e-9))
# False
# 相対許容誤差を 1e-3 に設定
print(torch.allclose(a, b, rtol=1e-3))
# True
この例では、atol
を小さくすると精度が上がり、rtol
を小さくすると許容される誤差が小さくなります。
NaN
の扱い
この例では、equal_nan
パラメータを使用して、NaN
を比較する方法を示します。
import torch
a = torch.tensor([1.0, 2.0, float('nan')])
b = torch.tensor([1.0, 2.0, float('nan')])
# NaN を等しいとみなさない場合
print(torch.allclose(a, b))
# False
# NaN を等しいとみなす場合
print(torch.allclose(a, b, equal_nan=True))
# True
この例では、equal_nan
を True
に設定すると、NaN
は等しいとみなされます。
この例では、torch.allclose
関数が2つの Tensor
の形状とデータ型が一致していることを前提としていることを示します。
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
# データ型が異なる場合
print(torch.allclose(a, b))
# RuntimeError: Can only compare Tensors of same dtype. Got: float32, float64
# 形状が異なる場合
b = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.allclose(a, b))
# RuntimeError: Can only compare Tensors of same size. Got: (3,), (4,)
この例では、2つの Tensor
の形状とデータ型が一致していない場合、エラーが発生します。
逐次比較
最も基本的な方法は、ループを使用して2つの Tensor
の要素を個別に比較することです。
import torch
def is_close(a, b, atol=1e-8, rtol=1e-5):
diff = abs(a - b)
epsilon = max(abs(a), abs(b)) * rtol
return (diff < atol) | (diff < epsilon)
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])
allclose = True
for i in range(len(a)):
if not is_close(a[i], b[i]):
allclose = False
break
print(allclose)
# True
長所
- メモリ使用量が少ない
- シンプルで分かりやすい
短所
- 複雑な比較には向かない
- 遅い
torch.norm 関数を使用する
torch.norm
関数を使用して、2つの Tensor
の間の距離を計算し、許容範囲内かどうかを判断することができます。
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])
threshold = atol + rtol * max(torch.norm(a), torch.norm(b))
distance = torch.norm(a - b)
allclose = distance < threshold
print(allclose)
# True
長所
- 逐次比較よりも高速
短所
- 許容範囲を計算するのが難しい
torch.norm
関数の計算コストが高い
カスタム比較関数を使用する
特定のニーズに合わせて、独自の比較関数を定義することができます。
import torch
def custom_compare(a, b):
# 独自の比較ロジックを実装
# ...
return result
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.00000001, 2.0, 3.0])
allclose = all(custom_compare(a[i], b[i]) for i in range(len(a)))
print(allclose)
# True
長所
- 複雑な比較に対応できる
- 柔軟性が高い
- 時間がかかる場合がある
- 実装が難しい