PyTorchのtorch.maskedで効率的なテンソル操作!実践的なコード例

2025-05-27

torch.masked モジュールには、いくつかの関数が含まれていますが、特によく使われるのは以下のものです。

  • torch.masked.masked_select(tensor, mask): maskTrue である要素に対応する tensor の要素を1次元のテンソルとして返します。masktensor は同じ形状である必要があります。


    import torch
    
    tensor = torch.arange(4).reshape(2, 2)
    mask = tensor % 2 == 0
    
    result = torch.masked.masked_select(tensor, mask)
    print(result)
    # 出力:
    # tensor([0, 2])
    

    この例では、masktensor の要素が偶数の位置が True)の True の位置に対応する tensor の要素(0 と 2)が1次元のテンソルとして抽出されています。

  • torch.masked.masked_fill(tensor, mask, value): maskTrue である要素の位置に対応する tensor の要素を、指定された value で埋めます。masktensor は同じ形状である必要があります。


    import torch
    
    tensor = torch.arange(4).reshape(2, 2).float()
    mask = tensor > 1
    
    torch.masked.masked_fill(tensor, mask, -1)
    print(tensor)
    # 出力:
    # tensor([[ 0.,  1.],
    #         [-1., -1.]])
    

    この例では、masktensor の要素が 1 より大きい位置が True)の True の位置に対応する tensor の要素が -1 で埋められています。

  • torch.masked.masked_scatter(tensor, mask, source): maskTrue である要素の位置に対応する tensor の要素を、source テンソルの要素で順番に上書きします。masktensor は同じ形状である必要があります。source は 1次元のテンソルであるか、maskTrue となる要素数と同じ要素数を持つテンソルである必要があります。


    import torch
    
    tensor = torch.zeros(2, 2)
    mask = torch.tensor([[True, False], [False, True]])
    source = torch.tensor([1, 2])
    
    torch.masked.masked_scatter(tensor, mask, source)
    print(tensor)
    # 出力:
    # tensor([[1., 0.],
    #         [0., 2.]])
    

    この例では、maskTrue の位置((0, 0) と (1, 1))に対応する tensor の要素が、source の要素(1 と 2)で順番に置き換えられています。

これらの関数を使うことで、特定の条件を満たすテンソルの要素に対して効率的に操作を行うことができます。例えば、損失関数の計算時に特定の値を除外したり、ある閾値を超える値だけを処理したりする際に役立ちます。



マスク(mask)の形状に関するエラー

  • トラブルシューティング
    • マスクテンソルと操作対象のテンソルの形状を注意深く確認してください。
    • unsqueeze(), squeeze(), expand(), reshape() などを用いて、マスクテンソルの形状を操作対象のテンソルの形状に合うように調整してください。
    • 特に masked_scatter の場合、マスクが True の要素数と source テンソルの要素数が一致している必要もあります。
  • 原因
    torch.masked の関数(masked_scatter, masked_fill, masked_select など)では、通常、操作対象のテンソルとマスクテンソルの形状が一部または完全に一致している必要があります。このエラーは、指定された次元において形状が一致していない場合に発生します。
  • エラー
    RuntimeError: The shape of the mask [形状1] at index [インデックス] must match the shape of the tensor [形状2] at index [インデックス]

データ型(dtype)に関するエラー

  • トラブルシューティング
    • マスクテンソルのデータ型を確認し、.bool() メソッドを用いて boolean 型に変換してください。
    • 例: mask = (tensor > 0).bool()
  • 原因
    マスクテンソルは通常、boolean型 (torch.bool) である必要があります。他のデータ型(例えば torch.uint8, torch.int64 など)をマスクとして使用すると、予期しない動作をしたり、エラーが発生したりする可能性があります。
  • エラー
    (具体的なエラーメッセージは状況によりますが、データ型の不一致を示唆する内容が多いです)

masked_scatter における source テンソルの要素数に関するエラー

  • トラブルシューティング
    • マスクが True となる要素の数をカウントし (mask.sum())、source テンソルの要素数がその数と一致しているか確認してください。
    • 必要に応じて、source テンソルを適切な要素数に調整するか、マスクの条件を見直してください。
  • 原因
    torch.masked.masked_scatter では、マスクが True である要素の数と、source テンソルの要素数が一致している必要があります。一致していない場合、どの要素をどの位置に書き込むかが曖昧になるためエラーが発生します。
  • エラー
    RuntimeError: The number of elements in source [要素数1] must match the number of masked elements [要素数2]

masked_select の結果の形状に関する誤解

  • トラブルシューティング
    • masked_select の結果は常に1次元のテンソルであることを理解してください。
    • 元の形状を保持したい場合は、masked_scatter や他の方法を検討する必要があります。
  • 原因
    torch.masked.masked_select は、マスクが True である要素を1次元のテンソルとして返します。元のテンソルの形状は保持されません。
  • 問題
    masked_select の結果が元のテンソルと同じ形状ではない。

ブロードキャスティングに関する注意点

  • トラブルシューティング
    • 明示的に形状を一致させるようにコードを記述することを推奨します。
    • ブロードキャスティングが意図通りに機能しているか確認するために、中間的なテンソルの形状を出力して確認してください。
  • 原因
    PyTorchのブロードキャスティングルールに従い、形状が異なるテンソル間でも特定の条件下で演算が可能になりますが、torch.masked の関数では、形状が明確に一致していることが期待される場合が多いです。
  • 問題
    マスクテンソルと操作対象のテンソルの形状が完全に一致しない場合、ブロードキャスティングが適用されることがありますが、意図しない結果になることがあります。
  • PyTorchのドキュメントを参照する
    torch.masked の各関数の詳細な仕様や例は、PyTorchの公式ドキュメントで確認できます。
  • 小さな例で試す
    問題が複雑な場合は、簡単なテンソルとマスクを作成して、意図した動作になるか試してみるのが有効です。
  • テンソルの形状とデータ型を常に意識する
    print(tensor.shape), print(tensor.dtype) などを用いて、テンソルの情報を確認することが重要です。
  • エラーメッセージを注意深く読む
    エラーメッセージには、問題の原因や場所に関する重要な情報が含まれています。


torch.masked.masked_scatter の例:特定の条件を満たす要素を別のテンソルから書き込む

import torch

# 元のテンソル
tensor = torch.zeros(5)
print("元のテンソル:", tensor)

# マスク(True の位置に書き込みを行う)
mask = torch.tensor([True, False, True, False, True])
print("マスク:", mask)

# 書き込む値を持つテンソル
source = torch.tensor([10, 20, 30])
print("書き込む値:", source)

# masked_scatter を使用して、マスクが True の位置に source の値を順番に書き込む
torch.masked.masked_scatter(tensor, mask, source)
print("masked_scatter 後のテンソル:", tensor)

説明

この例では、tensor のうち、maskTrue であるインデックス(0, 2, 4)に対応する要素が、source テンソルの要素(10, 20, 30)で順番に上書きされます。source テンソルの要素数は、maskTrue となる要素の数と一致している必要があります。

torch.masked.masked_fill の例:特定の条件を満たす要素を特定の値で埋める

import torch

# テンソル
tensor = torch.arange(10).float().reshape(2, 5)
print("元のテンソル:\n", tensor)

# マスク(要素が 5 より大きい位置を True とする)
mask = tensor > 5
print("マスク:\n", mask)

# 埋める値
fill_value = -1.0

# masked_fill を使用して、マスクが True の位置を fill_value で埋める
new_tensor = torch.masked.masked_fill(tensor, mask, fill_value)
print("masked_fill 後のテンソル:\n", new_tensor)

説明

この例では、tensor の要素のうち、maskTrue である(つまり、5 より大きい)位置の要素が -1.0 で置き換えられた新しいテンソル new_tensor が作成されます。元の tensor は変更されません(masked_fill は新しいテンソルを返します)。

torch.masked.masked_select の例:特定の条件を満たす要素を抽出する

import torch

# テンソル
tensor = torch.randn(3, 4)
print("元のテンソル:\n", tensor)

# マスク(絶対値が 0.5 より大きい位置を True とする)
mask = torch.abs(tensor) > 0.5
print("マスク:\n", mask)

# masked_select を使用して、マスクが True の要素を抽出する
selected_elements = torch.masked.masked_select(tensor, mask)
print("masked_select で抽出された要素:\n", selected_elements)
print("抽出された要素の形状:", selected_elements.shape)

説明

この例では、tensor の要素のうち、maskTrue である(つまり、絶対値が 0.5 より大きい)要素が抽出され、1次元のテンソル selected_elements として返されます。元のテンソルの形状は保持されません。

応用例:損失関数の計算における特定の値の無視

import torch
import torch.nn.functional as F

# 予測値と正解ラベル
logits = torch.randn(5, 3)
targets = torch.randint(0, 3, (5,))
print("予測値 (logits):\n", logits)
print("正解ラベル (targets):\n", targets)

# 無視するラベルの値(例えば -1)
ignore_index = -1

# 無視する位置を示すマスクを作成
mask = targets != ignore_index
print("無視しない要素のマスク:\n", mask)

# マスクを使用して、無視しない要素の損失のみを計算
loss = F.cross_entropy(logits, targets, ignore_index=ignore_index, reduction='none')
masked_loss = torch.masked.masked_select(loss, mask)
print("各要素の損失:\n", loss)
print("無視しない要素の損失:\n", masked_loss)

# 平均損失を計算(無視された要素は含まれない)
mean_loss = masked_loss.mean()
print("平均損失 (無視された要素を除く):", mean_loss)

説明

この例は、損失関数を計算する際に、特定のラベル(ここでは -1)を無視する方法を示しています。まず、無視しない要素に対応するマスクを作成し、そのマスクを使って torch.masked.masked_select で損失値を選択的に抽出しています。これにより、無視したい要素の損失が平均損失の計算に含まれなくなります。F.cross_entropyignore_index パラメータを指定する方法もありますが、torch.masked を使うことでより柔軟な条件での要素選択が可能になります。



booleanインデックス参照(Boolean Indexing)

最も一般的で強力な代替方法は、boolean型のテンソルをインデックスとして使用する方法です。これは、torch.masked.masked_select と同様の要素抽出や、torch.masked.masked_fill のような特定条件下の要素の書き換えを実現できます。

例1: torch.masked.masked_select の代替

import torch

tensor = torch.randn(3, 4)
mask = torch.abs(tensor) > 0.5

# booleanインデックス参照で True の要素を抽出
selected_elements = tensor[mask]
print("抽出された要素:\n", selected_elements)

説明

tensor[mask] のように、boolean型の mask テンソルをインデックスとして使用すると、mask の値が True である位置に対応する tensor の要素が1次元のテンソルとして返されます。これは torch.masked.masked_select(tensor, mask) と同じ結果になります。

例2: torch.masked.masked_fill の代替

import torch

tensor = torch.arange(10).float().reshape(2, 5)
mask = tensor > 5
fill_value = -1.0

# booleanインデックス参照で True の位置に値を代入
new_tensor = tensor.clone()  # 元のテンソルを保護するためにコピーを作成
new_tensor[mask] = fill_value
print("masked_fill の代替後のテンソル:\n", new_tensor)

説明

tensor[mask] = value のように、boolean型の mask を左辺に用いると、maskTrue である位置の tensor の要素に value を代入できます。torch.masked.masked_fill と異なり、この方法は元のテンソルを直接変更するため、必要に応じて .clone() でコピーを作成する必要があります。

torch.where 関数

torch.where(condition, x, y) 関数は、condition (booleanテンソル) が True の要素に対しては x の対応する要素を、False の要素に対しては y の対応する要素を持つ新しいテンソルを返します。これは、条件に基づいて要素を選択的に変更する際に便利で、torch.masked.masked_fill の一部の機能と似たことができます。


import torch

tensor = torch.arange(10).float().reshape(2, 5)
mask = tensor > 5
fill_value = -1.0

# torch.where を使用して、True の位置を fill_value で、False の位置を元の値で埋める
new_tensor = torch.where(mask, torch.full_like(tensor, fill_value), tensor)
print("torch.where を使用した結果:\n", new_tensor)

説明

この例では、maskTrue の位置には fill_value と同じ形状のテンソルの要素が、False の位置には元の tensor の要素がそのまま使われた新しいテンソル new_tensor が作成されます。

ループ処理(Pythonの制御構造)

PyTorchのテンソル操作は通常、効率のためにベクトル化されていますが、小規模な処理や複雑な条件分岐が必要な場合には、Pythonの for ループや if 文などを組み合わせて処理することも可能です。ただし、パフォーマンスの観点からは、できる限りテンソル演算を用いることが推奨されます。


import torch

tensor = torch.arange(5).float()
mask = torch.tensor([True, False, True, False, True])
fill_value = -1.0

new_list = []
for i in range(tensor.size(0)):
    if mask[i]:
        new_list.append(fill_value)
    else:
        new_list.append(tensor[i].item())

new_tensor = torch.tensor(new_list)
print("ループ処理の結果:", new_tensor)

説明

この例では、ループを使って各要素をチェックし、マスクが True なら fill_value を、False なら元の値をリストに追加し、最後にテンソルに変換しています。これは torch.masked.masked_fill の簡単な代替ですが、大きなテンソルに対しては非効率です。

torch.masked.masked_scatter の代替について

torch.masked.masked_scatter の完全に直接的な代替は、booleanインデックス参照だけでは少し複雑になります。なぜなら、masked_scattersource テンソルの要素を順番に書き込むからです。しかし、もし書き込む値が特定の値で一定であれば、booleanインデックス参照で代用できます。

もし source テンソルから順番に書き込む必要がある場合は、マスクを使って書き込むべきインデックスを取得し、そのインデックスに対して source の要素を代入する処理を組み合わせる必要があります。

import torch

tensor = torch.zeros(5)
mask = torch.tensor([True, False, True, False, True])
source = torch.tensor([10, 20, 30])

# マスクが True のインデックスを取得
indices = torch.nonzero(mask).squeeze(1)
print("True のインデックス:", indices)

# これらのインデックスに source の値を代入
tensor[indices] = source
print("代替後のテンソル:", tensor)

説明

ここでは、torch.nonzero(mask)maskTrue のインデックスを取得し、.squeeze(1) で形状を調整しています。その後、このインデックスを使って tensorsource の値を代入しています。ただし、この方法は source の要素数とマスクの True の数が一致していることを前提としています。