PyTorch でランダム数を再現可能にする: `torch.Generator.get_state()` 関数


この関数の動作

torch.Generator.get_state() 関数は、現在のジェネレータの状態を torch.ByteTensor として返します。このテンソルには、ジェネレータの状態を復元するために必要な情報が含まれています。

この関数の使用例

以下の例は、torch.Generator.get_state() 関数を使用して、ランダムな整数を生成し、その状態を保存および復元する方法を示しています。

import torch

# ジェネレータを作成
generator = torch.Generator()

# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)

# ジェネレータの状態を取得
state = generator.get_state()

# 新しいジェネレータを作成
new_generator = torch.Generator()

# 保存した状態を設定
new_generator.set_state(state)

# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)

この例では、最初のランダムな整数と 2 番目のランダムな整数は同じであることがわかります。これは、torch.Generator.set_state() 関数を使用してジェネレータの状態を復元したためです。

  • ジェネレータの状態は、異なるバージョンの PyTorch 間で互換性がない可能性があります。異なるバージョンの PyTorch 間でランダムな結果を再現するには、シード値を使用する必要があります。
  • torch.Generator.get_state() 関数は、CPU 上のデフォルトのジェネレータの状態のみを取得します。他のジェネレータの状態を取得するには、generator.get_state() メソッドを使用する必要があります。


import torch

# ジェネレータを作成
generator = torch.Generator()

# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)

# ジェネレータの状態を取得
state = generator.get_state()

# 新しいジェネレータを作成
new_generator = torch.Generator()

# 保存した状態を設定
new_generator.set_state(state)

# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)

このコードを実行すると、以下の出力が得られます。

tensor([8 4 2 9 5 0 7 3 1 6])
tensor([8 4 2 9 5 0 7 3 1 6])

一様分布からのランダムな浮動小数点数を生成する

import torch

# ジェネレータを作成
generator = torch.Generator()

# ランダムな浮動小数点を作成
random_floats = torch.rand(10, generator=generator)
print(random_floats)
import torch

# ジェネレータを作成
generator = torch.Generator()

# 法則分布からのランダムな整数を作成
random_integers = torch.randint(2, 10, (10,), generator=generator)
print(random_integers)
import torch

# ジェネレータを作成
generator = torch.Generator()

# 正規分布からのランダムな浮動小数点を作成
random_normals = torch.randn(10, generator=generator)
print(random_normals)


シード値を使用する

PyTorch Generator の状態は、シード値を使用して初期化されます。シード値がわかれば、torch.manual_seed() 関数を使用してジェネレータを初期化し、同じランダムな結果を再現することができます。

import torch

# シード値を設定
seed = 1234

# シード値を使用してジェネレータを初期化
generator = torch.Generator(manual_seed=seed)

# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)

# 同じシード値を使用して新しいジェネレータを初期化
new_generator = torch.Generator(manual_seed=seed)

# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)

この例では、最初のランダムな整数と 2 番目のランダムな整数は同じであることがわかります。これは、同じシード値を使用してジェネレータを初期化したためです。

torch.get_rng_state() 関数を使用する

torch.get_rng_state() 関数は、現在のデフォルトの RNG (Random Number Generator) の状態を取得するために使用されます。この状態を使用して、PyTorch Generator の状態を復元することができます。

import torch

# デフォルトの RNG の状態を取得
state = torch.get_rng_state()

# 新しいジェネレータを作成
new_generator = torch.Generator()

# 保存した状態を設定
new_generator.set_state(state)

# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)

この例では、torch.get_rng_state() 関数を使用して取得した状態を使用してジェネレータを初期化するため、ランダムな整数は前の例と同じになります。

カスタム状態クラスを使用する

独自のジェネレータ状態クラスを作成することもできます。このクラスは、必要な情報をすべて格納する必要があります。

import torch

class MyGeneratorState(object):
    def __init__(self, state):
        self.state = state

    def __repr__(self):
        return f"MyGeneratorState(state={self.state})"

# ジェネレータを作成
generator = torch.Generator()

# ランダムな整数を作成
random_numbers = torch.randint(0, 10, (10,), generator=generator)
print(random_numbers)

# ジェネレータの状態を取得
state = MyGeneratorState(generator.get_state())
print(state)

# 新しいジェネレータを作成
new_generator = torch.Generator()

# 保存した状態を設定
new_generator.set_state(state.state)

# 新しいジェネレータを使用してランダムな整数を作成
new_random_numbers = torch.randint(0, 10, (10,), generator=new_generator)
print(new_random_numbers)

この例では、MyGeneratorState というカスタム状態クラスを作成し、ジェネレータの状態を格納しています。この状態を使用して、新しいジェネレータを初期化し、同じランダムな結果を再現することができます。

torch.Generator.get_state() 関数以外にも、PyTorch Generator の状態を取得するにはいくつかの方法があります。それぞれの方法には長所と短所があるため、要件に応じて適切な方法を選択する必要があります。

  • カスタム状態クラスを使用する
    最も柔軟な方法ですが、実装が複雑になります。
  • torch.get_rng_state() 関数を使用する:** デフォルトの RNG の状態を取得できますが、他のジェネレータの状態を取得するには使用できません。
  • シード値を使用する
    最も簡単で効率的な方法ですが、ランダムな結果を完全に制御することはできません。