TensorBoard で Precision-Recall 曲線を可視化: `torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve()` の徹底解説


この関数の主な役割は以下の通りです

  • 曲線と情報の可視化
    Tensorboard を起動することで、記録された Precision-Recall 曲線と関連する情報を可視化することができます。
  • Tensorboard への記録
    生成された Precision-Recall 曲線と関連する情報を Tensorboard に記録します。
  • Precision-Recall 曲線の生成
    予測確率と真のラベルに基づいて、Precision-Recall 曲線を生成します。

この関数の引数は以下の通りです

  • num_thresholds (int, optional)
    Precision-Recall 曲線を作成するために使用する閾値の数です。デフォルトは 127 です。
  • global_step (int, optional)
    イベントを記録する際のグローバルステップ値です。デフォルトは None です。
  • predictions (torch.Tensor, numpy.ndarray, or str/blobname)
    予測確率を表すテンソル、NumPy 配列、または文字列です。
  • labels (torch.Tensor, numpy.ndarray, or str/blobname)
    真のラベルを表すテンソル、NumPy 配列、または文字列です。
  • tag (str)
    曲線に付けるタグ名です。

この関数の例は以下の通りです

import torch
import torch.utils.tensorboard as tb

# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])

# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')

# Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve', labels, predictions)

# Tensorboard を起動
tensorboard --logdir logs
  • Precision-Recall 曲線以外にも、ROC 曲線や F1 スコアなどの指標を Tensorboard で可視化することができます。
  • torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve() 関数は、PyTorch 1.1 以降で使用できます。


例 1: 複数の曲線を同時に追加

この例では、異なるモデルの予測確率と真のラベルを使用して、複数の Precision-Recall 曲線を同時に Tensorboard に追加します。

import torch
import torch.utils.tensorboard as tb

# 2つのモデルの予測確率と真のラベルを生成
predictions_model1 = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels_model1 = torch.tensor([1, 0, 1, 0, 1])

predictions_model2 = torch.tensor([0.2, 0.5, 0.8, 0.3, 0.7])
labels_model2 = torch.tensor([1, 0, 1, 0, 1])

# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')

# それぞれのモデルの Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve_model1', labels_model1, predictions_model1)
writer.add_pr_curve('pr_curve_model2', labels_model2, predictions_model2)

# Tensorboard を起動
tensorboard --logdir logs

例 2: カスタム閾値を使用して曲線を生成

この例では、num_thresholds 引数を使用して、Precision-Recall 曲線を作成するために使用する閾値の数を変えます。

import torch
import torch.utils.tensorboard as tb

# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])

# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')

# 異なる閾値の数を使用して Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve_10_thresholds', labels, predictions, num_thresholds=10)
writer.add_pr_curve('pr_curve_50_thresholds', labels, predictions, num_thresholds=50)

# Tensorboard を起動
tensorboard --logdir logs

例 3: グローバルステップを指定

この例では、global_step 引数を使用して、イベントを記録する際のグローバルステップ値を指定します。

import torch
import torch.utils.tensorboard as tb

# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])

# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')

# グローバルステップを指定して Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve', labels, predictions, global_step=100)

# Tensorboard を起動
tensorboard --logdir logs


matplotlib を使用した手動プロット

  • 短所:
    • Tensorboard に統合されていないため、ワークフローが煩雑になる可能性がある
    • コードが冗長になり、メンテナンスが難しくなる可能性がある
  • 長所:
    • コードの制御とカスタマイズ性の自由度が高い
    • 特定のニーズに合わせてグラフを個別に調整できる
import matplotlib.pyplot as plt
import numpy as np

# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])

# Precision-Recall 曲線を計算
precision, recall, thresholds = precision_recall_curve(labels, predictions)

# Precision-Recall 曲線をプロット
plt.plot(thresholds, precision, label='Precision')
plt.plot(thresholds, recall, label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Precision/Recall')
plt.title('Precision-Recall Curve')
plt.legend()
plt.show()

scikit-learn を使用した手動計算と可視化

  • 短所:
    • Tensorboard に統合されていないため、ワークフローが煩雑になる可能性がある
    • カスタム要件に柔軟に対応できない場合がある
  • 長所:
    • precision_recall_curve 関数など、Precision-Recall 曲線分析のための便利なユーティリティを提供
    • plot_precision_recall_curve 関数を使用して、簡単にグラフを生成できる
from sklearn.metrics import precision_recall_curve

# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])

# Precision-Recall 曲線を計算
precision, recall, thresholds = precision_recall_curve(labels, predictions)

# Precision-Recall 曲線をプロット
plt.plot(thresholds, precision, label='Precision')
plt.plot(thresholds, recall, label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Precision/Recall')
plt.title('Precision-Recall Curve')
plt.legend()
plt.show()

サードパーティのライブラリを使用

  • 短所:
    • サードパーティ製ライブラリの追加インストールと設定が必要
    • オープンソースではないライブラリは、商用利用にライセンス料金がかかる場合があります
  • 長所:
    • WandbComet ML のようなライブラリは、Tensorboard と統合された拡張機能を提供し、Precision-Recall 曲線を含む豊富な可視化機能を提供
    • 実験管理、コラボレーション、結果の追跡などの追加機能を提供する場合がある


import wandb

# Wandb を初期化
wandb.init(project='my-project')

# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])

# Precision-Recall 曲線をログ
wandb.log({"pr_curve": wandb.plots.precision_recall(predictions, labels)})

カスタムロジックの実装

  • 短所:
    • 複雑で時間がかかる場合がある
    • 専門知識とデバッグスキルが必要
  • 長所:
    • 完全な制御と柔軟性
    • 特定のニーズや要件に合わせてアルゴリズムを完全にカスタマイズできる

上記の代替方法はそれぞれ長所と短所があるため、状況に合わせて最適な方法を選択することが重要です

  • データの量と複雑性
    大規模で複雑なデータセットを扱う場合は、scikit-learn やサードパーティのライブラリなどのスケーラブルなソリューションが適している場合があります