【超解説】PyTorch NNファンクション nll_lossの使い方:サンプルコード付き


具体的な動作

torch.nn.functional.nll_loss は、以下の 2 つの引数を受け取ります。

  1. input (Tensor)
    モデルが出力した非正規化された対数確率分布を表すテンソルです。形状は (N, C) または (N, C, H, W) になります。ここで、
    • N はバッチサイズ
    • C はクラス数
    • H は入力の特徴マップの高さ (2D 損失の場合のみ)
    • W は入力の特徴マップの幅 (2D 損失の場合のみ)
  2. target (Tensor)
    正解ラベルを表すテンソルです。形状は (N) または (N, H, W) になります。ここで、各要素は 0 から C-1 までの整数値を表します。

この関数は、以下の式に基づいて損失を計算します。

loss = -∑(i = 0; i < N; i++) log(input[i, target[i]])

ここで、log は自然対数関数を表します。

オプション引数

torch.nn.functional.nll_loss には、以下のオプション引数が用意されています。

  • ignore_index (int)
    無視するラベルインデックスを指定する整数です。デフォルトは -100 で、この場合対応する要素は損失計算から除外されます。
  • reduction (str)
    損失の集約方法を指定する文字列です。デフォルトは 'mean' で、この場合バッチ内の損失の平均が計算されます。他のオプションとしては 'sum' や 'none' があります。
  • weight (Tensor)
    各クラスに対する損失の重み付けを指定するテンソルです。形状は (C) になります。デフォルトは None で、この場合すべてのクラスに同じ重みが適用されます。

以下のコード例は、torch.nn.functional.nll_loss を使って損失を計算する方法を示しています。

import torch
import torch.nn.functional as F

# 入力データと正解ラベルを作成
input = torch.randn(10, 3)
target = torch.LongTensor([1, 2, 0, 2, 0, 1, 2, 0, 1, 2])

# 損失を計算
loss = F.nll_loss(input, target)
print(loss)

このコードを実行すると、以下の出力が得られます。

tensor(1.4963)


import torch
import torch.nn.functional as F

# データの準備
num_classes = 3  # クラス数
batch_size = 10  # バッチサイズ

# 入力データ (バッチサイズ x クラス数)
input = torch.randn(batch_size, num_classes)

# 正解ラベル (バッチサイズ)
target = torch.LongTensor([1, 2, 0, 2, 0, 1, 2, 0, 1, 2])

# 損失の計算
loss = F.nll_loss(input, target)
print(f"損失: {loss}")

# 各クラスに対する損失の重み付け
weights = torch.tensor([2., 1., 3.])  # クラス0: 2, クラス1: 1, クラス2: 3

# 重み付き損失の計算
weighted_loss = F.nll_loss(input, target, weight=weights)
print(f"重み付き損失: {weighted_loss}")

# 無視インデックスの設定
ignore_index = 0  # クラス0を無視

# 無視インデックス付き損失の計算
loss_with_ignore_index = F.nll_loss(input, target, ignore_index=ignore_index)
print(f"無視インデックス付き損失: {loss_with_ignore_index}")
  1. num_classesbatch_size 変数を定義して、クラス数とバッチサイズを設定します。
  2. inputtarget 変数を作成して、入力データと正解ラベルをランダムな値で初期化します。
  3. F.nll_loss を使用して損失を計算し、結果をコンソールに出力します。
  4. weights 変数を作成して、各クラスに対する損失の重み付けを設定します。
  5. 重み付き損失を計算し、結果をコンソールに出力します。
  6. ignore_index 変数を作成して、無視するラベルインデックスを設定します。
  7. 無視インデックス付き損失を計算し、結果をコンソールに出力します。
  • 損失の計算以外にも、torch.nn.functional.nll_loss 関数は、クラス確率の対数尤度や予測分布のエントロピーなどの値を計算するために使用できます。
  • このコードは PyTorch 1.9.0 で動作確認しています。


torch.nn.CrossEntropyLoss

torch.nn.CrossEntropyLoss は、torch.nn.functional.nll_loss とほぼ同等の機能を持つ関数ですが、以下の点で利点があります。

  • コードの簡潔化
    これにより、コードがより簡潔になり、読みやすくなります。
  • LogSoftmax との組み合わせが不要
    torch.nn.functional.nll_loss は、入力に対して LogSoftmax 関数を適用する必要がありますが、torch.nn.CrossEntropyLoss は内部で LogSoftmax を適用するため、このステップが不要になります。

一方、torch.nn.CrossEntropyLoss には、以下の欠点もあります。

  • 古いバージョンの PyTorch では利用不可
    PyTorch 1.6 以前のバージョンの PyTorch では利用できません。
  • オプション引数が少ない
    torch.nn.functional.nll_loss に比べてオプション引数が少ないため、柔軟性が低くなります。

以下のコード例は、torch.nn.functional.nll_losstorch.nn.CrossEntropyLoss を使用して損失を計算する方法を比較しています。

import torch
import torch.nn.functional as F

# データの準備
num_classes = 3
batch_size = 10

input = torch.randn(batch_size, num_classes)
target = torch.LongTensor([1, 2, 0, 2, 0, 1, 2, 0, 1, 2])

# nll_loss を使用した損失計算
nll_loss = F.nll_loss(F.log_softmax(input, dim=1), target)
print(f"nll_loss: {nll_loss}")

# CrossEntropyLoss を使用した損失計算
cross_entropy_loss = F.cross_entropy(input, target)
print(f"CrossEntropyLoss: {cross_entropy_loss}")

このコードを実行すると、両方の方法で同じ損失値が出力されることが確認できます。

独自の損失計算ロジックが必要な場合は、カスタム損失関数を作成することができます。これは、複雑な損失関数や、特定のタスクに特化した損失関数を定義する場合に役立ちます。

カスタム損失関数は、torch.nn.Module を継承したクラスとして定義できます。このクラスには、損失を計算する forward メソッドを実装する必要があります。

以下のコード例は、カスタム損失関数の簡単な例を示しています。

import torch
import torch.nn as nn

class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # カスタム損失計算ロジックを実装
        loss = 0.0
        for i in range(len(input)):
            loss += (input[i] - target[i])**2
        return loss

# データの準備
num_classes = 3
batch_size = 10

input = torch.randn(batch_size, num_classes)
target = torch.LongTensor([1, 2, 0, 2, 0, 1, 2, 0, 1, 2])

# カスタム損失関数を用いた損失計算
criterion = MyLoss()
loss = criterion(input, target)
print(f"カスタム損失: {loss}")

このコードは、二乗平均誤差に基づいたカスタム損失関数の例です。

上記以外にも、Kullback-Leibler ダイバージェンスやジェンセン-シャノン ダイバージェンスなどの情報理論に基づいた損失関数を使用することもできます。これらの損失関数は、クラス確率分布間の距離を測定するために役立ちます。