PyTorch でランダム数を再現可能にする: `torch.Generator.get_state()` 関数
この関数の動作
torch.Generator.get_state()
関数は、現在のジェネレータの状態を torch.ByteTensor
として返します。このテンソルには、ジェネレータの状態を復元するために必要な情報が含まれています。
この関数の使用例
以下の例は、torch.Generator.get_state()
関数を使用して、ランダムな整数を生成し、その状態を保存および復元する方法を示しています。
import torch
# ジェネレータを作成
generator = torch.Generator()
# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)
# ジェネレータの状態を取得
state = generator.get_state()
# 新しいジェネレータを作成
new_generator = torch.Generator()
# 保存した状態を設定
new_generator.set_state(state)
# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)
この例では、最初のランダムな整数と 2 番目のランダムな整数は同じであることがわかります。これは、torch.Generator.set_state()
関数を使用してジェネレータの状態を復元したためです。
- ジェネレータの状態は、異なるバージョンの PyTorch 間で互換性がない可能性があります。異なるバージョンの PyTorch 間でランダムな結果を再現するには、シード値を使用する必要があります。
torch.Generator.get_state()
関数は、CPU 上のデフォルトのジェネレータの状態のみを取得します。他のジェネレータの状態を取得するには、generator.get_state()
メソッドを使用する必要があります。
import torch
# ジェネレータを作成
generator = torch.Generator()
# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)
# ジェネレータの状態を取得
state = generator.get_state()
# 新しいジェネレータを作成
new_generator = torch.Generator()
# 保存した状態を設定
new_generator.set_state(state)
# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)
このコードを実行すると、以下の出力が得られます。
tensor([8 4 2 9 5 0 7 3 1 6])
tensor([8 4 2 9 5 0 7 3 1 6])
一様分布からのランダムな浮動小数点数を生成する
import torch
# ジェネレータを作成
generator = torch.Generator()
# ランダムな浮動小数点を作成
random_floats = torch.rand(10, generator=generator)
print(random_floats)
import torch
# ジェネレータを作成
generator = torch.Generator()
# 法則分布からのランダムな整数を作成
random_integers = torch.randint(2, 10, (10,), generator=generator)
print(random_integers)
import torch
# ジェネレータを作成
generator = torch.Generator()
# 正規分布からのランダムな浮動小数点を作成
random_normals = torch.randn(10, generator=generator)
print(random_normals)
シード値を使用する
PyTorch Generator の状態は、シード値を使用して初期化されます。シード値がわかれば、torch.manual_seed()
関数を使用してジェネレータを初期化し、同じランダムな結果を再現することができます。
import torch
# シード値を設定
seed = 1234
# シード値を使用してジェネレータを初期化
generator = torch.Generator(manual_seed=seed)
# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)
# 同じシード値を使用して新しいジェネレータを初期化
new_generator = torch.Generator(manual_seed=seed)
# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)
この例では、最初のランダムな整数と 2 番目のランダムな整数は同じであることがわかります。これは、同じシード値を使用してジェネレータを初期化したためです。
torch.get_rng_state() 関数を使用する
torch.get_rng_state()
関数は、現在のデフォルトの RNG (Random Number Generator) の状態を取得するために使用されます。この状態を使用して、PyTorch Generator の状態を復元することができます。
import torch
# デフォルトの RNG の状態を取得
state = torch.get_rng_state()
# 新しいジェネレータを作成
new_generator = torch.Generator()
# 保存した状態を設定
new_generator.set_state(state)
# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)
この例では、torch.get_rng_state()
関数を使用して取得した状態を使用してジェネレータを初期化するため、ランダムな整数は前の例と同じになります。
カスタム状態クラスを使用する
独自のジェネレータ状態クラスを作成することもできます。このクラスは、必要な情報をすべて格納する必要があります。
import torch
class MyGeneratorState(object):
def __init__(self, state):
self.state = state
def __repr__(self):
return f"MyGeneratorState(state={self.state})"
# ジェネレータを作成
generator = torch.Generator()
# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)
# ジェネレータの状態を取得
state = MyGeneratorState(generator.get_state())
print(state)
# 新しいジェネレータを作成
new_generator = torch.Generator()
# 保存した状態を設定
new_generator.set_state(state.state)
# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)
この例では、MyGeneratorState
というカスタム状態クラスを作成し、ジェネレータの状態を格納しています。この状態を使用して、新しいジェネレータを初期化し、同じランダムな結果を再現することができます。
torch.Generator.get_state()
関数以外にも、PyTorch Generator の状態を取得するにはいくつかの方法があります。それぞれの方法には長所と短所があるため、要件に応じて適切な方法を選択する必要があります。
- カスタム状態クラスを使用する
最も柔軟な方法ですが、実装が複雑になります。 torch.get_rng_state()
関数を使用する:** デフォルトの RNG の状態を取得できますが、他のジェネレータの状態を取得するには使用できません。- シード値を使用する
最も簡単で効率的な方法ですが、ランダムな結果を完全に制御することはできません。