テンサーの精度検証をレベルアップ!PyTorchの`torch.testing.assert_close()`関数と代替方法


動作原理

torch.testing.assert_close() は、以下の式に基づいて2つのテンサー actualexpected を比較します。

|actual - expected| <= atol + rtol * |expected|

ここで、

  • rtol: 相対許容誤差を表します。これは、expected の値の一定割合を許容される最大差として設定します。
  • atol: 絶対許容誤差を表します。これは、expected の値に関係なく、actualexpected の間の許容される最大差を表します。

2つのテンサーが上記の式を満たす場合、assert_close() は成功し、テストは合格となります。一方、式を満たさない場合、assert_close() は失敗し、テストは不合格となります。

主な引数

torch.testing.assert_close() は、以下の引数を受け取ります。

  • check_device: Trueの場合、2つのテンサーが同じデバイス上にあることを確認します(デフォルト: True)
  • check_dtype: Trueの場合、2つのテンサーのデータ型が一致していることを確認します(デフォルト: True)
  • equal_nan: Trueの場合、NaN値も等しいとみなされます(デフォルト: False)
  • rtol: 相対許容誤差(デフォルト: 1e-05
  • atol: 絶対許容誤差(デフォルト: 1e-08
  • expected: 比較対象となる2番目のテンサー
  • actual: 比較対象となる最初のテンサー

以下の例は、torch.testing.assert_close() を使用して2つのテンサーを比較する方法を示しています。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0001, 2.0002, 3.0003])

# 許容誤差を設定
atol = 1e-3
rtol = 1e-4

# テストを実行
tt.assert_close(actual, expected, atol=atol, rtol=rtol)

上記の例では、atolrtol をそれぞれ 1e-31e-4 に設定しています。これは、expected の値の0.1%未満の差であれば許容されることを意味します。

  • torch.testing.assert_close() は、PyTorchのテストにおいて、テンサーの精度を検証するための強力なツールです。許容誤差を適切に設定することで、モデルの訓練や推論の過程で生成された結果の信頼性を確認することができます。
  • torch.testing.assert_close() は、torch.allclose() と似ていますが、いくつかの重要な違いがあります。
    • torch.allclose() は、テンサー内のすべての要素が等しいかどうかを検証します。一方、torch.testing.assert_close() は、許容範囲内で近似的に等しいかどうかを検証します。
    • torch.allclose() は、NaN値を考慮しません。一方、torch.testing.assert_close() は、equal_nan オプションを指定することでNaN値を考慮することができます。


基本的な例

この例では、2つのテンサーを絶対許容誤差と相対許容誤差を使用して比較します。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0001, 2.0002, 3.0003])

# 許容誤差を設定
atol = 1e-3
rtol = 1e-4

# テストを実行
tt.assert_close(actual, expected, atol=atol, rtol=rtol)

絶対許容誤差のみを使用する例

この例では、絶対許容誤差のみを使用して2つのテンサーを比較します。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.001, 2.002, 3.003])

# 許容誤差を設定
atol = 1e-2

# テストを実行
tt.assert_close(actual, expected, atol=atol)
import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([100.0, 200.0, 300.0])
expected = torch.tensor([100.1, 200.2, 300.3])

# 許容誤差を設定
rtol = 1e-3

# テストを実行
tt.assert_close(actual, expected, rtol=rtol)

NaN値を考慮する例

この例では、equal_nan オプションを使用して、NaN値を考慮したテンサーの比較を行います。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, float('nan')])
expected = torch.tensor([1.0001, 2.0002, float('nan')])

# テストを実行
tt.assert_close(actual, expected, equal_nan=True)

データ型のチェックを無効にする例

この例では、check_dtype オプションを使用して、データ型のチェックを無効にします。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
expected = torch.tensor([1.0001, 2.0002, 3.0003], dtype=torch.float64)

# テストを実行
tt.assert_close(actual, expected, check_dtype=False)

この例では、check_device オプションを使用して、デバイスのチェックを無効にします。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0, 2.0, 3.0], device='cpu')
expected = torch.tensor([1.0001, 2.0002, 3.0003], device='cuda')

# テストを実行
tt.assert_close(actual, expected, check_device=False)

これらの例は、torch.testing.assert_close() の様々な使用方法を示しています。具体的な状況に合わせて、適切なオプションを設定して使用してください。

  • torch.testing.assert_close() は、PyTorchのテストにおいて、テンサーの精度を検証するための強力なツールです。許容誤差を適切に設定することで、モデルの訓練や推論の過程で生成された結果の信頼性を確認することができます。


torch.allclose()

torch.allclose() は、torch.testing.assert_close() と似ていますが、以下の点が異なります。

  • NaN値を考慮しません。
  • デフォルトの許容誤差は atol=1e-09rtol=1e-5 です。
  • 絶対許容誤差と相対許容誤差の代わりに、許容誤差の閾値のみを指定します。

以下の例は、torch.allclose() を使用して2つのテンサーを比較する方法を示しています。

import torch
import torch.testing as tt

# テスト対象のテンサーを作成
actual = torch.tensor([1.0001, 2.0002, 3.0003])
expected = torch.tensor([1.0, 2.0, 3.0])

# 許容誤差を設定
rtol = 1e-3

# テストを実行
tt.allclose(actual, expected, rtol=rtol)

カスタム断言関数

許容誤差の計算ロジックや、NaN値の扱いなどを細かく制御したい場合は、カスタム断言関数を作成することもできます。

以下の例は、カスタム断言関数を使用して2つのテンサーを比較する方法を示しています。

import torch

def my_assert_close(actual, expected, atol=1e-3, rtol=1e-5, equal_nan=False):
    # 絶対許容誤差と相対許容誤差に基づいて差を計算
    diff = actual - expected
    abs_diff = torch.abs(diff)
    rel_diff = abs_diff / torch.abs(expected)

    # 許容範囲内に収まっているかどうかを確認
    if (diff <= atol) or (rel_diff <= rtol):
        return True
    else:
        if equal_nan and torch.isnan(diff).any():
            return True
        else:
            return False

# テスト対象のテンサーを作成
actual = torch.tensor([1.0001, 2.0002, 3.0003])
expected = torch.tensor([1.0, 2.0, 3.0])

# テストを実行
my_assert_close(actual, expected)

NumPyやSciPyなどの他のライブラリも、テンサーの比較に使用することができます。

  • SciPy: scipy.allclose()
  • NumPy: np.allclose()

これらのライブラリの関数は、PyTorchの torch.testing.assert_close() とは異なるオプションや機能を提供している場合があります。

選択の指針

  • 特定のライブラリとの整合性を保ちたい: NumPyやSciPy
  • 許容誤差の計算ロジックを制御したい: カスタム断言関数
  • シンプルで使いやすい: torch.testing.assert_close()
  • 他のライブラリの関数は、PyTorchのテンサーとは異なる形式でデータを受け入れる場合があるため、注意が必要です。
  • カスタム断言関数を作成する場合は、テストのロジックが明確で分かりやすいように記述することが重要です。
  • torch.allclose()torch.testing.assert_close() と完全に互換性があるわけではありません。