TensorBoard で Precision-Recall 曲線を可視化: `torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve()` の徹底解説
この関数の主な役割は以下の通りです
- 曲線と情報の可視化
Tensorboard を起動することで、記録された Precision-Recall 曲線と関連する情報を可視化することができます。 - Tensorboard への記録
生成された Precision-Recall 曲線と関連する情報を Tensorboard に記録します。 - Precision-Recall 曲線の生成
予測確率と真のラベルに基づいて、Precision-Recall 曲線を生成します。
この関数の引数は以下の通りです
- num_thresholds (int, optional)
Precision-Recall 曲線を作成するために使用する閾値の数です。デフォルトは 127 です。 - global_step (int, optional)
イベントを記録する際のグローバルステップ値です。デフォルトは None です。 - predictions (torch.Tensor, numpy.ndarray, or str/blobname)
予測確率を表すテンソル、NumPy 配列、または文字列です。 - labels (torch.Tensor, numpy.ndarray, or str/blobname)
真のラベルを表すテンソル、NumPy 配列、または文字列です。 - tag (str)
曲線に付けるタグ名です。
この関数の例は以下の通りです
import torch
import torch.utils.tensorboard as tb
# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])
# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')
# Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve', labels, predictions)
# Tensorboard を起動
tensorboard --logdir logs
- Precision-Recall 曲線以外にも、ROC 曲線や F1 スコアなどの指標を Tensorboard で可視化することができます。
torch.utils.tensorboard.writer.SummaryWriter.add_pr_curve()
関数は、PyTorch 1.1 以降で使用できます。
例 1: 複数の曲線を同時に追加
この例では、異なるモデルの予測確率と真のラベルを使用して、複数の Precision-Recall 曲線を同時に Tensorboard に追加します。
import torch
import torch.utils.tensorboard as tb
# 2つのモデルの予測確率と真のラベルを生成
predictions_model1 = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels_model1 = torch.tensor([1, 0, 1, 0, 1])
predictions_model2 = torch.tensor([0.2, 0.5, 0.8, 0.3, 0.7])
labels_model2 = torch.tensor([1, 0, 1, 0, 1])
# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')
# それぞれのモデルの Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve_model1', labels_model1, predictions_model1)
writer.add_pr_curve('pr_curve_model2', labels_model2, predictions_model2)
# Tensorboard を起動
tensorboard --logdir logs
例 2: カスタム閾値を使用して曲線を生成
この例では、num_thresholds
引数を使用して、Precision-Recall 曲線を作成するために使用する閾値の数を変えます。
import torch
import torch.utils.tensorboard as tb
# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])
# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')
# 異なる閾値の数を使用して Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve_10_thresholds', labels, predictions, num_thresholds=10)
writer.add_pr_curve('pr_curve_50_thresholds', labels, predictions, num_thresholds=50)
# Tensorboard を起動
tensorboard --logdir logs
例 3: グローバルステップを指定
この例では、global_step
引数を使用して、イベントを記録する際のグローバルステップ値を指定します。
import torch
import torch.utils.tensorboard as tb
# 予測確率と真のラベルを生成
predictions = torch.tensor([0.1, 0.4, 0.9, 0.2, 0.6])
labels = torch.tensor([1, 0, 1, 0, 1])
# Tensorboard writer を作成
writer = tb.SummaryWriter('logs')
# グローバルステップを指定して Precision-Recall 曲線を Tensorboard に記録
writer.add_pr_curve('pr_curve', labels, predictions, global_step=100)
# Tensorboard を起動
tensorboard --logdir logs
matplotlib を使用した手動プロット
- 短所:
- Tensorboard に統合されていないため、ワークフローが煩雑になる可能性がある
- コードが冗長になり、メンテナンスが難しくなる可能性がある
- 長所:
- コードの制御とカスタマイズ性の自由度が高い
- 特定のニーズに合わせてグラフを個別に調整できる
import matplotlib.pyplot as plt
import numpy as np
# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])
# Precision-Recall 曲線を計算
precision, recall, thresholds = precision_recall_curve(labels, predictions)
# Precision-Recall 曲線をプロット
plt.plot(thresholds, precision, label='Precision')
plt.plot(thresholds, recall, label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Precision/Recall')
plt.title('Precision-Recall Curve')
plt.legend()
plt.show()
scikit-learn を使用した手動計算と可視化
- 短所:
- Tensorboard に統合されていないため、ワークフローが煩雑になる可能性がある
- カスタム要件に柔軟に対応できない場合がある
- 長所:
precision_recall_curve
関数など、Precision-Recall 曲線分析のための便利なユーティリティを提供plot_precision_recall_curve
関数を使用して、簡単にグラフを生成できる
from sklearn.metrics import precision_recall_curve
# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])
# Precision-Recall 曲線を計算
precision, recall, thresholds = precision_recall_curve(labels, predictions)
# Precision-Recall 曲線をプロット
plt.plot(thresholds, precision, label='Precision')
plt.plot(thresholds, recall, label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Precision/Recall')
plt.title('Precision-Recall Curve')
plt.legend()
plt.show()
サードパーティのライブラリを使用
- 短所:
- サードパーティ製ライブラリの追加インストールと設定が必要
- オープンソースではないライブラリは、商用利用にライセンス料金がかかる場合があります
- 長所:
Wandb
やComet ML
のようなライブラリは、Tensorboard と統合された拡張機能を提供し、Precision-Recall 曲線を含む豊富な可視化機能を提供- 実験管理、コラボレーション、結果の追跡などの追加機能を提供する場合がある
例
import wandb
# Wandb を初期化
wandb.init(project='my-project')
# 予測確率と真のラベルを生成
predictions = np.array([0.1, 0.4, 0.9, 0.2, 0.6])
labels = np.array([1, 0, 1, 0, 1])
# Precision-Recall 曲線をログ
wandb.log({"pr_curve": wandb.plots.precision_recall(predictions, labels)})
カスタムロジックの実装
- 短所:
- 複雑で時間がかかる場合がある
- 専門知識とデバッグスキルが必要
- 長所:
- 完全な制御と柔軟性
- 特定のニーズや要件に合わせてアルゴリズムを完全にカスタマイズできる
上記の代替方法はそれぞれ長所と短所があるため、状況に合わせて最適な方法を選択することが重要です。
- データの量と複雑性
大規模で複雑なデータセットを扱う場合は、scikit-learn
やサードパーティのライブラリなどのスケーラブルなソリューションが適している場合があります