PyTorchで乱数生成の再現性を担保する方法:torch.random.get_rng_state() 解説とサンプルコード


torch.random.get_rng_state() 関数は、PyTorchにおける乱数生成器の状態を取得するために使用されます。この状態は、現在の乱数生成シーケンスにおける次の乱数の値を決定するために使用されます。この関数は、主に以下の2つの目的で使用されます。

  1. 乱数生成の再現性確保
    同じ状態から開始すれば、同じ乱数シーケンスを繰り返し生成することができます。これは、デバッグや研究において重要です。
  2. 異なるデバイス間での乱数生成の同期
    複数のデバイスで並列に処理を実行する場合、各デバイスで同じ乱数シーケンスを使用する必要があります。この関数は、各デバイスの状態を同期するために使用することができます。

関数詳細

torch.random.get_rng_state() -> torch.ByteTensor

この関数は、現在の乱数生成器の状態を torch.ByteTensor として返します。このテンソルは、CPU 上のデフォルトジェネレータの状態のみを保持します。GPU 上のジェネレータの状態を取得するには、torch.cuda.get_rng_state() 関数を使用する必要があります。

以下の例は、torch.random.get_rng_state()torch.random.set_rng_state() 関数を使用して、乱数生成の再現性を示します。

import torch

# 乱数シードを設定
torch.manual_seed(1234)

# 乱数を生成
x1 = torch.randn(10)

# 乱数生成器の状態を取得
state = torch.random.get_rng_state()

# 乱数生成器の状態を設定
torch.random.set_rng_state(state)

# 同じ状態から同じ乱数を生成
x2 = torch.randn(10)

print(x1)
print(x2)

このコードを実行すると、x1x2 は同じ値を出力します。これは、同じ乱数生成器の状態から乱数が生成されているためです。

応用例

torch.random.get_rng_state() 関数は、様々な場面で使用することができます。以下に、いくつかの例を挙げます。

  • 並列処理
    複数のデバイスで並列に処理を実行する場合、この関数は各デバイスで同じ乱数シーケンスを使用するために使用することができます。これにより、異なるデバイス間で結果の一貫性を保つことができます。
  • 研究
    研究において、ランダムな値がモデルのパフォーマンスに与える影響を調査する場合、この関数は役立ちます。同じ乱数シーケンスを繰り返し生成することで、結果の信頼性を高めることができます。
  • デバッグ
    モデルの動作がランダムな値に依存している場合、この関数はデバッグに役立ちます。同じ乱数シーケンスを繰り返し生成することで、問題を特定しやすくなります。
  • 乱数生成器の状態は、バイナリデータとして保存されます。このデータを改ざんすると、乱数生成の結果が予期せぬものになる可能性があります。
  • この関数は、CPU 上のデフォルトジェネレータの状態のみを返します。GPU 上のジェネレータの状態を取得するには、torch.cuda.get_rng_state() 関数を使用する必要があります。


乱数生成と状態取得

import torch

# 乱数シードを設定
torch.manual_seed(1234)

# 乱数を生成
x = torch.randn(10)

# 乱数生成器の状態を取得
state = torch.random.get_rng_state()

print(x)
print(state)

このコードを実行すると、以下の出力が得られます。

tensor([ 0.8031,  0.0349, -0.0843, -1.2051,  1.0492,  0.7436,  0.2302,
        0.9003, -0.3448, -0.2773])
torch.ByteTensor([  1, 174, 234,  77, 161,  40, 186, 249,  66,  33])
  • 2番目の行は、乱数生成器の状態 state をバイナリデータとして表示します。
  • 最初の行は、乱数生成されたテンソル x を表示します。

状態設定と再現性検証

import torch

# 乱数シードを設定
torch.manual_seed(1234)

# 乱数を生成
x1 = torch.randn(10)

# 乱数生成器の状態を取得
state = torch.random.get_rng_state()

# 乱数生成器の状態を設定
torch.random.set_rng_state(state)

# 同じ状態から同じ乱数を生成
x2 = torch.randn(10)

print(x1)
print(x2)
tensor([ 0.8031,  0.0349, -0.0843, -1.2051,  1.0492,  0.7436,  0.2302,
        0.9003, -0.3448, -0.2773])
tensor([ 0.8031,  0.0349, -0.0843, -1.2051,  1.0492,  0.7436,  0.2302,
        0.9003, -0.3448, -0.2773])
  • 3番目の行と4番目の行は同じ値を出力しており、これは同じ状態から同じ乱数が生成されていることを示しています。
  • 2番目の行は、torch.random.set_rng_state() 関数を使用して状態 state を設定した後に生成されたテンソル x2 を表示します。
  • 最初の行は、乱数生成されたテンソル x1 を表示します。
import torch

# 乱数シードを設定
torch.manual_seed(1234)

# GPU デバイスを作成
device = torch.device("cuda")

# 乱数を生成 (GPU上)
x = torch.randn(10, device=device)

# 乱数生成器の状態を取得 (GPU上)
state = torch.cuda.get_rng_state()

print(x)
print(state)

このコードを実行するには、GPU が搭載されている環境が必要です。

tensor([ 1.2091, -0.6623, -0.4043,  0.4751,  0.0394, -0.1507, -0.7934,
        0.7892,  0.5001,  1.0194], device='cuda:0')
torch.ByteTensor([  1, 174, 234,  77, 161,  40, 186, 249,  66,  33])
  • 最初の行は、GPU 上で生成されたテンソル


Generatorオブジェクトの複製

PyTorch 1.7.0 以降では、torch.random.get_rng_state() を使用せずに乱数生成器の状態を複製する方法があります。これは、Generator オブジェクトのコピーを作成することで実現できます。

import torch

# 乱数シードを設定
torch.manual_seed(1234)

# 乱数生成器を作成
rng = torch.random.default_generator()

# 乱数を生成
x1 = rng.normal(10)

# 乱数生成器を複製
rng2 = rng.clone()

# 複製された乱数生成器から乱数を生成
x2 = rng2.normal(10)

print(x1)
print(x2)

このコードを実行すると、x1x2 は同じ値を出力します。これは、同じ乱数生成器の状態から乱数が生成されていることを示しています。

torch.random.fork_rng() コンテキストマネージャー

torch.random.fork_rng() コンテキストマネージャーを使用して、乱数生成器の状態を一時的に保存し、後で復元することができます。これは、複数の処理間で乱数生成の一貫性を保ちたい場合に役立ちます。

import torch

# 乱数シードを設定
torch.manual_seed(1234)

with torch.random.fork_rng():
  # 乱数を生成
  x1 = torch.randn(10)

# 乱数生成器の状態を復元
torch.random.fork_rng().set_rng_state(torch.random.get_rng_state())

# 同じ状態から同じ乱数を生成
x2 = torch.randn(10)

print(x1)
print(x2)

このコードを実行すると、x1x2 は同じ値を出力します。これは、torch.random.fork_rng() コンテキストマネージャーを使用して乱数生成器の状態を保存および復元しているためです。

カスタム乱数状態管理

上記の方法で十分でない場合は、カスタムの乱数状態管理ロジックを実装することもできます。これは、複雑なワークフローや高度な制御が必要な場合に役立ちます。

  • カスタム乱数状態管理を実装する場合は、乱数生成の一貫性と再現性を確保するために注意する必要があります。
  • 上記の代替方法は、PyTorch 1.7.0 以降でのみ使用できます。