PyTorchでSoftplus関数を使いこなす:MNISTデータセットを用いた手書き数字認識モデルの実装例
torch.nn.Softplus
は、以下の式で表される滑らかな近似 ReLU 関数です。
Softplus(x) = 1/β * log(1 + exp(β * x))
ここで、
β
はパラメータ(デフォルト値は 1)x
は入力値
Softplus
関数は、入力値が負のときでも常に非負の値を出力します。また、入力値が大きくなるにつれて、出力値は緩やかに増加していきます。
torch.nn.Softplus
関数は、以下の利点があります。
- ReLU 関数よりも滑らかな導関数を持つため、勾配消失問題が発生しにくいです。
以下は、torch.nn.Softplus
関数を使用する簡単な例です。
import torch
import torch.nn as nn
# ニューラルネットワークを作成します。
model = nn.Sequential(
nn.Linear(10, 20),
nn.Softplus(),
nn.Linear(20, 1)
)
# 入力データを作成します。
x = torch.randn(10, 10)
# ニューラルネットワークを実行します。
y = model(x)
# 出力結果を出力します。
print(y)
コード
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# データセットの読み込み
train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root="data", train=False, transform=transforms.ToTensor())
# データローダーの作成
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# ニューラルネットワークの構築
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.act1 = nn.Softplus()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.act1(x)
x = self.fc2(x)
return x
model = Net().to(device)
# 損失関数と最適化アルゴリズムの設定
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# モデルの訓練
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
# モデルの評価
model.eval()
with torch.no_grad():
correct = 0
total = 0
for data in test_loader:
images, labels = data[0].to(device), data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
コードの説明
- モデルの訓練後、MNIST テストデータセットを使用してモデルの精度が評価されます。
- モデルは、Adam 最適化アルゴリズムを使用して訓練されます。
torch.nn.Softplus
関数は、Net
クラスのforward
メソッドで活性化関数として使用されています。- このコードは、PyTorch 1.9.0 および Python 3.7 で動作確認済みです。
実行方法
このコードを実行するには、以下の手順を実行してください。
- Python と PyTorch をインストールします。
- コードを保存し、
mnist_softplus.py
のような名前で保存します。 - 以下のコマンドを実行して、コードを実行します。
python mnist_softplus.py
- モデルの訓練には時間がかかる場合があります。
- このコードはあくまで例であり、学習率やバッチサイズなどのハイパーパラメータを調整する必要がある場合があります。
torch.nn.ReLU
- 短所:
- 入力値が0以下の場合、出力値が0になってしまう(死んだニューロン問題)
- 勾配消失問題が発生しやすい
- 長所:
- シンプルで計算コストが低い
- 出力値が常に非負
torch.nn.SELU
- 短所:
ReLU
よりも計算コストが高い
- 長所:
ReLU
よりも滑らかな導関数を持つため、勾配消失問題が発生しにくい- 出力値が常に非負
torch.nn.Swish
- 短所:
ReLU
やSELU
よりも計算コストが高い
- 長所:
ReLU
とSELU
の利点を組み合わせたような特性を持つ- 滑らかな導関数を持つ
- 出力値が常に非負
torch.nn.Tansig
- 短所:
- 勾配消失問題が発生しやすい
- 出力値が常に非負ではない
- 長所:
- 出力値が-1から1の範囲に制限されるため、データの正規化に役立つ
カスタム活性化関数
- 短所:
- 設計と実装が複雑になる
- 長所:
- 特定のタスクに適した特性を持つ活性化関数を設計できる
- 出力値の範囲: 出力値が常に非負である必要がある場合は、
ReLU
やSELU
などの活性化関数を検討する必要があります。 - 勾配消失問題: 勾配消失問題が懸念される場合は、
SELU
やSwish
のような滑らかな導関数を持つ活性化関数を検討する必要があります。 - 計算コスト: 計算コストが制約となる場合は、
ReLU
のようなシンプルな活性化関数を検討する必要があります。 - タスク: 使用するニューラルネットワークのタスクによって、適切な代替方法が異なります。例えば、画像認識タスクでは
ReLU
がよく使用されますが、自然言語処理タスクではSELU
やSwish
がよく使用されます。