PyTorchのGPU活用術:torch.mpsからCUDA、MLXまで徹底比較

2025-05-26

torch.mpsとは何か?

  • PyTorchとの統合: torch.mpsは、PyTorchの計算グラフやプリミティブ(基本的な演算)をMPSフレームワークにマッピングし、Metal GPU上で効率的に実行できるようにします。これにより、macOSデバイス上での機械学習モデルのトレーニングや推論のパフォーマンスが向上します。
  • Metal Performance Shaders (MPS): Appleが提供する、GPU上で高性能な計算を行うためのフレームワークです。機械学習やグラフィックス処理に最適化されたカーネル(GPU上で実行されるプログラム)が含まれています。

主な特徴とメリット

  1. Apple Siliconに最適化: Apple Silicon(M1, M2, M3チップなど)を搭載したMacのGPUの性能を最大限に活用できます。これらのチップは、CPU、GPU、ニューラルエンジン、ユニファイドメモリを統合したSoC(System on a Chip)であり、PyTorchがこれらのハードウェアの利点を活かせるように設計されています。
  2. GPUアクセラレーション: CPUのみで実行するよりも、はるかに高速にモデルのトレーニングや推論を行うことができます。特に、大規模なモデルやバッチサイズでの効果が顕著です。
  3. ユニファイドメモリ: Apple Siliconのユニファイドメモリアーキテクチャにより、CPUとGPUが同じメモリを共有するため、データ転送の遅延が少なくなり、エンドツーエンドのパフォーマンスが向上します。
  4. 既存のPyTorchスクリプトとの互換性: 既存のPyTorchコードを少し修正するだけで、MPSデバイスを利用できます。通常は、テンソルやモデルを"mps"デバイスに移動するだけです。
    import torch
    
    if torch.backends.mps.is_available():
        mps_device = torch.device("mps")
        x = torch.ones(5, device=mps_device) # MPSデバイス上にテンソルを作成
        model = YourFavoriteNet().to(mps_device) # モデルをMPSデバイスに移動
        # その後の計算はGPUで行われる
        y = x * 2
        pred = model(x)
    else:
        print("MPS device not found.")
    
  5. 開発環境としての利便性: Macユーザーにとって、クラウドベースの開発環境や追加のGPUハードウェアなしに、ローカルで機械学習のプロトタイピングやファインチューニングを行うことが可能になります。

torch.mpsはまだ比較的新しいバックエンドであり、いくつかの制限や注意点があります。

  1. オペレーターのサポート状況: すべてのPyTorchオペレーターが完全にMPSに移植されているわけではありません。未サポートのオペレーターが含まれる場合、パフォーマンスが低下したり、エラーが発生したりする可能性があります。
    • PYTORCH_ENABLE_MPS_FALLBACK=1環境変数を設定することで、未サポートのオペレーターをCPUにフォールバックさせることができますが、この場合、CPUとMPSデバイス間のデータ転送が発生し、パフォーマンスが低下します。
    • 特に、一部の複雑なオペレーター(例: LSTMの特定の引数設定)でバグや非互換性が報告された時期もありましたが、PyTorchチームによる継続的な開発により改善が進んでいます。
  2. 数値精度: MPSはMetalの「fast math」を利用しているため、一部の計算でCUDAやCPUと異なる数値結果をもたらす可能性があります。厳密な数値精度が求められるアプリケーションでは注意が必要です。
  3. メモリ使用量: 特にユニファイドメモリが少ないMacデバイスの場合、大規模なモデルやバッチサイズではメモリ不足(Out Of Memory, OOM)が発生する可能性があります。メモリ管理のためのAPI(例: torch.mps.set_per_process_memory_fraction)も提供されていますが、注意が必要です。
  4. 単一デバイスのサポート: 現在のところ、MPSは単一のGPUデバイスのみをサポートしており、複数のMPS対応GPUを搭載したマシンでの分散学習はサポートされていません。
  5. NDArrayサイズの上限: PyTorchのMPSバックエンドは、232を超えるNDArrayサイズをサポートしていません。


torch.mpsでよくあるエラーとトラブルシューティング

NotImplementedError: Could not run 'aten::...' with arguments from the 'MPS' backend. (未実装オペレーター)

  • トラブルシューティング:
    1. PyTorchのバージョンアップ: PyTorchは継続的にMPSのサポートを改善しています。最新の安定版またはナイトリービルドにアップデートすることで、未サポートだったオペレーターが追加されている可能性があります。
      pip install --upgrade torch torchvision torchaudio
      
      または、ナイトリービルド(開発版)を試すことも有効です。
      pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
      
      (注: --index-urlはCPU版のURLですが、MPS対応版が同梱されていることが多いです。公式ドキュメントで最新のインストール方法を確認してください。)
    2. CPUフォールバックの有効化: 環境変数PYTORCH_ENABLE_MPS_FALLBACK=1を設定することで、MPSがサポートしていないオペレーターを自動的にCPUにフォールバックさせることができます。
      # ターミナルで実行
      export PYTORCH_ENABLE_MPS_FALLBACK=1
      # またはPythonコード内で
      import os
      os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
      
      注意点: フォールバックすると、CPUとMPSデバイス間でデータ転送が発生するため、パフォーマンスが低下する可能性があります。また、GPUとCPUで計算が分散されることで、デバッグが難しくなることもあります。
    3. モデルの変更: 可能であれば、未サポートのオペレーターを使用しないようにモデルアーキテクチャを変更することを検討します。これは通常、最終手段です。
  • エラーの原因: 最も一般的なエラーの一つです。PyTorchのすべてのオペレーター(演算)がMPSバックエンドで完全にサポートされているわけではありません。モデルが未サポートのオペレーターを使用しようとすると、このエラーが発生します。特に、比較的新しいモデルアーキテクチャやあまり使われない演算で起こりやすいです。

計算結果の不一致 (Correctness issues / Silent correctness issues)

  • トラブルシューティング:
    1. PyTorchのバージョンアップ: 最も簡単な解決策は、PyTorchを最新版にアップデートすることです。バグは継続的に修正されています。
    2. フォールバックの利用: 不正確な結果を返すオペレーターが特定できる場合は、その部分だけCPUにフォールバックさせることを検討します。
    3. 精度設定の確認: 特にモデルのトレーニング時や推論時、torch.float32を使用しているか確認してください。場合によっては、torch.float16(半精度)の使用によって問題が発生することもあります。混合精度トレーニング(torch.autocast(device_type="mps"))を使用している場合は、その設定が適切か確認します。
    4. 再現性の確認: シード値を固定し、CPUとMPSの両方で同じ入力データに対して計算を行い、結果が一致しないことを確認します。結果のわずかな違いは、浮動小数点演算の性質上、許容される場合があります。
  • エラーの原因: MPSバックエンドで計算された結果が、CPUやCUDA(NVIDIA GPU)で計算された結果と異なる場合があります。これは、主に以下の理由によります。
    • 数値精度: MPSはMetalの「fast math」を利用しているため、一部の計算でCUDAやCPUと異なる数値結果をもたらす可能性があります。特に、浮動小数点演算の丸め誤差や実装の違いが影響します。
    • 特定のオペレーターの実装バグ: 特定のオペレーター(例: torch.where()torch.nonzero()、一部の分散関連演算)で、MPSの実装にバグがあり、間違った結果を返すことが報告された事例があります。

Out Of Memory (OOM) エラー / メモリ不足

  • トラブルシューティング:
    1. バッチサイズの削減: モデルの入力データのバッチサイズを小さくします。これは最も一般的な対策です。
    2. モデルサイズの縮小: 可能であれば、より小さなモデル(パラメータ数が少ないモデル)を使用することを検討します。
    3. 不要な変数の削除: 計算後に不要になったテンソルや変数をdelで削除し、torch.mps.empty_cache()を呼び出してキャッシュをクリアします。
      del some_tensor
      torch.mps.empty_cache() # MPSの内部キャッシュを解放
      
    4. 混合精度トレーニングの利用: torch.amp.autocast(device_type="mps")を使用して混合精度トレーニングを有効にすると、メモリ使用量を大幅に削減できる場合があります。
      with torch.amp.autocast(device_type="mps"):
          output = model(input)
          loss = criterion(output, target)
      scaler.scale(loss).backward()
      scaler.step(optimizer)
      scaler.update()
      
    5. メモリ使用量の上限設定: 特定のPyTorchバージョンでは、MPSデバイスのメモリ使用量の上限を設定できる機能が提供されている場合があります。
      # 例: MPSデバイスのメモリ使用量を全体の50%に制限
      # このAPIは変更される可能性があるため、公式ドキュメントで確認してください
      # torch.mps.set_per_process_memory_fraction(0.5)
      
  • エラーの原因: MacのGPUメモリ(ユニファイドメモリ)は、システムメモリと共有されています。大規模なモデルや大きなバッチサイズを使用すると、メモリを使い果たし、OOMエラーが発生する可能性があります。

パフォーマンスが期待通りに出ない / CPUよりも遅い

  • トラブルシューティング:
    1. PYTORCH_ENABLE_MPS_FALLBACK=1でのログ確認: フォールバックを有効にして実行し、どのオペレーターがCPUにフォールバックしているかを示す警告メッセージを確認します。これにより、パフォーマンスボトルネックの特定に役立ちます。
    2. プロファイリング: torch.profilerなどのツールを使用して、どこで時間がかかっているかを詳細に分析します。CPUとMPSのどちらで多くの時間が費やされているかを確認できます。
    3. テンソルのデバイス固定: テンソルが常にMPSデバイス上にあることを確認し、不要なデバイス間の移動を避けます。
      # テンソルをMPSデバイスに作成
      x = torch.randn(100, 100, device="mps")
      # モデルもMPSデバイスに移動
      model.to("mps")
      
    4. PyTorchのバージョンアップ: 最新版ではパフォーマンス改善が含まれていることがあります。
  • エラーの原因:
    1. 未サポートオペレーターによるCPUフォールバック: 前述の通り、多くのオペレーターがCPUにフォールバックしている場合、GPUの恩恵が失われ、かえって遅くなることがあります。
    2. 小さなモデルやバッチサイズ: GPUは並列計算に優れていますが、オーバーヘッドも存在します。非常に小さなモデルやバッチサイズの場合、CPUでの計算の方が速いことがあります。
    3. データ転送のオーバーヘッド: CPUとGPU間のデータ転送はコストがかかります。頻繁にデバイス間でテンソルを移動している場合、パフォーマンスが低下します。
    4. MPSバックエンドの成熟度: CUDAに比べて、MPSバックエンドはまだ発展途上であり、すべてのシナリオでCUDAと同等の最適化が施されているわけではありません。

PyTorchのインストールと環境設定の問題

  • トラブルシューティング:
    1. MPSの可用性確認: コードの冒頭でtorch.backends.mps.is_available()Trueを返すか確認します。
      import torch
      if torch.backends.mps.is_available():
          print("MPS is available!")
          device = torch.device("mps")
      else:
          print("MPS is not available.")
          device = torch.device("cpu")
      
    2. macOSのバージョン: torch.mpsはmacOS 12.3以降(Monterey)が必要です。macOSのバージョンを確認し、必要であればアップデートします。
    3. PyTorchのインストール方法: 公式サイトの指示に従って、MPS対応のPyTorchをインストールしているか確認します。通常、pip install torch torchvision torchaudioで最新版をインストールすればMPSも含まれていますが、特定の古いバージョンでは利用できない場合があります。
    4. Python環境: 仮想環境(venvcondaなど)を使用しているか確認し、意図しないPythonのバージョンやパッケージが使用されていないことを確認します。
  • エラーの原因: torch.mpsが利用できない場合、PyTorchのインストールが正しく行われていないか、macOSのバージョンが古すぎる可能性があります。
  • CPUでテスト: MPSでエラーが発生した場合、同じコードをCPU(device="cpu")で実行してみて、問題がMPS固有のものなのか、それともコード全体の問題なのかを切り分けます。
  • 簡略化された再現コード: 問題が発生した場合、できるだけ簡単なコードでその問題を再現できるか試します。これにより、問題の切り分けが容易になります。
  • PyTorchのGitHub Issueを検索: 経験した問題が既知のバグである可能性があります。PyTorchのGitHubリポジトリのIssueセクションで検索すると、解決策や進捗状況が見つかることがあります。
  • エラーメッセージの確認: エラーメッセージは、問題の根源を示す重要な情報を含んでいます。特に、NotImplementedErrorの場合は、どのオペレーターが問題なのかが示されます。


MPSデバイスの利用可能性の確認とテンソルの作成

まず、MPSデバイスが利用可能かどうかを確認し、利用可能な場合はテンソルをMPSデバイス上に作成します。

import torch

# MPSデバイスの利用可能性を確認
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    print("MPSデバイスが利用可能です。")
else:
    print("MPSデバイスは利用できません。CPUを使用します。")
    mps_device = torch.device("cpu")

print(f"現在のデバイス: {mps_device}")

# MPSデバイス上にテンソルを作成
x_mps = torch.randn(3, 4, device=mps_device)
print(f"MPS上のテンソルx_mps:\n{x_mps}")
print(f"x_mpsのデバイス: {x_mps.device}")

# CPU上にテンソルを作成し、MPSデバイスに移動
x_cpu = torch.zeros(2, 3)
print(f"\nCPU上のテンソルx_cpu:\n{x_cpu}")
x_mps_moved = x_cpu.to(mps_device)
print(f"MPSに移動したテンソルx_mps_moved:\n{x_mps_moved}")
print(f"x_mps_movedのデバイス: {x_mps_moved.device}")

# MPSデバイス上のテンソルでの演算
y_mps = x_mps * 2 + 1
print(f"\nMPS上のテンソルでの演算結果y_mps:\n{y_mps}")
print(f"y_mpsのデバイス: {y_mps.device}")

# CPUにテンソルを戻す
y_cpu_back = y_mps.to("cpu")
print(f"\nCPUに戻したテンソルy_cpu_back:\n{y_cpu_back}")
print(f"y_cpu_backのデバイス: {y_cpu_back.device}")

解説:

  • MPSデバイス上のテンソルに対する演算は、自動的にGPU上で実行されます。
  • torch.randn(..., device=mps_device)tensor.to(mps_device)を使って、テンソルをMPSデバイス上に配置または移動します。
  • torch.device("mps")でMPSデバイスオブジェクトを作成します。
  • torch.backends.mps.is_available()でMPSが利用可能かチェックします。

ニューラルネットワークモデルのMPSデバイスへの移動とトレーニング

シンプルなニューラルネットワークを定義し、それをMPSデバイスに移動してトレーニングする例です。

import torch
import torch.nn as nn
import torch.optim as optim

# デバイスの準備
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPSデバイスが利用可能です。トレーニングにMPSを使用します。")
else:
    device = torch.device("cpu")
    print("MPSデバイスは利用できません。トレーニングにCPUを使用します。")

# 非常にシンプルなニューラルネットワークの定義
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# モデルを定義し、デバイスに移動
model = SimpleNet().to(device)
print(f"\nモデルが配置されたデバイス: {next(model.parameters()).device}")

# 損失関数とオプティマイザ
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# ダミーデータの生成 (device=device を忘れずに)
# 入力: 100サンプル、特徴量10
# 出力: 100サンプル、出力1
X_train = torch.randn(100, 10, device=device)
y_train = torch.randn(100, 1, device=device) # 正解データ

# トレーニングループ
num_epochs = 100
print("\nトレーニング開始...")
for epoch in range(num_epochs):
    # フォワードパス
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # バックワードパスと最適化
    optimizer.zero_grad() # 勾配をゼロクリア
    loss.backward()       # 逆伝播
    optimizer.step()      # パラメータ更新

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("トレーニング終了。")

# 推論 (トレーニング済みのモデルで)
# 新しい入力データもMPSデバイスに配置
X_test = torch.randn(5, 10, device=device)
model.eval() # 評価モードに設定 (DropoutやBatchNormが無効になる)
with torch.no_grad(): # 勾配計算を無効化
    predictions = model(X_test)
print(f"\n推論結果 (MPS上):\n{predictions}")

解説:

  • トレーニングループは通常のPyTorchのそれと変わりません。
  • 入力データX_trainy_traindevice=deviceを指定してMPSデバイス上に作成します。データとモデルは同じデバイス上に存在する必要があります。
  • model.to(device)を使って、モデルのすべてのパラメータをMPSデバイスに移動します。一度移動すれば、その後のフォワードパスやバックワードパスはGPU上で行われます。

混合精度トレーニング (torch.amp) の利用

torch.amp (Automatic Mixed Precision) は、モデルのメモリ使用量を削減し、トレーニング速度を向上させるために、半精度浮動小数点(float16)と単精度浮動小数点(float32)を組み合わせて使用する技術です。MPSでも利用可能です。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler # MPSでもこのScalerを使用

# デバイスの準備
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPSデバイスが利用可能です。混合精度トレーニングにMPSを使用します。")
else:
    device = torch.device("cpu")
    print("MPSデバイスは利用できません。CPUを使用します。")

# シンプルなネットワーク (Example 2と同じ)
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = SimpleNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# GradScalerを初期化 (MPSでもCUDAのScalerを使用)
scaler = GradScaler()

X_train = torch.randn(100, 10, device=device)
y_train = torch.randn(100, 1, device=device)

num_epochs = 100
print("\n混合精度トレーニング開始...")
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # autocastコンテキスト内でフォワードパスを実行
    with torch.autocast(device_type="mps"): # device_typeを"mps"に指定
        outputs = model(X_train)
        loss = criterion(outputs, y_train)

    # スケーリングされた損失でバックワードパス
    scaler.scale(loss).backward()

    # オプティマイザーステップの実行
    scaler.step(optimizer)

    # スケーラーの更新
    scaler.update()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("混合精度トレーニング終了。")

解説:

  • scaler.scale(loss).backward()scaler.step(optimizer)scaler.update() は、勾配のアンダフローを防ぎながら、混合精度でトレーニングを行うための標準的なAMPワークフローです。
  • with torch.autocast(device_type="mps"): ブロック内でフォワードパスを実行することで、互換性のある演算が自動的に半精度で実行されます。
  • from torch.cuda.amp import GradScaler を使用しますが、GradScaler はCUDAだけでなくMPSでも動作します。

長時間トレーニングを行う場合や、大規模なデータセットを扱う場合にメモリを解放するために使用します。

import torch

if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    print("MPSデバイスが利用可能です。")
else:
    print("MPSデバイスは利用できません。")
    exit() # MPSがない場合はここで終了

# ダミーデータを作成 (MPSデバイス上)
a = torch.randn(1000, 1000, device=mps_device)
b = torch.randn(1000, 1000, device=mps_device)
c = a @ b # 行列積

print(f"テンソルcのデバイス: {c.device}")

# 計算完了後、不要になったテンソルを削除
del a
del b
del c

# MPSキャッシュをクリアし、メモリを解放
# これにより、システムが利用可能なメモリが増える可能性がある
torch.mps.empty_cache()
print("MPSキャッシュをクリアしました。")

# 再びテンソルを作成して計算
d = torch.ones(500, 500, device=mps_device)
print(f"テンソルdのデバイス: {d.device}")

解説:

  • delでPythonの参照を削除した後、torch.mps.empty_cache()を呼び出すことで、MPSが内部的に保持しているキャッシュを解放し、OSにメモリを返します。これは、特に長いトレーニングプロセスで、メモリ使用量が時間とともに増加するのを防ぐのに役立ちます。


ここでは、torch.mpsの代替方法や、それに代わる(あるいは補完する)アプローチについて説明します。

CPU (Central Processing Unit)

最も基本的な代替手段であり、どのコンピューターでも利用可能です。

  • コード例:
    import torch
    
    device = torch.device("cpu") # 明示的にCPUデバイスを指定
    
    x = torch.randn(10, 10, device=device)
    model = YourModel().to(device)
    
    print(f"現在のデバイス: {device}")
    print(f"テンソルxのデバイス: {x.device}")
    print(f"モデルのデバイス: {next(model.parameters()).device}")
    
  • 利用場面:
    • 小規模なモデルのプロトタイピングやデバッグ。
    • MPSで未サポートのオペレーターが含まれる場合のフォールバック。
    • CPUの計算能力で十分なバッチ処理や推論。
  • 特徴:
    • 高い互換性: どのPyTorchオペレーターもCPU上で動作します。
    • どこでも動作: 特殊なハードウェアやドライバは不要です。
    • 低速: 大規模なモデルのトレーニングや推論には向いていません。特にディープラーニングにおいては、GPUの並列処理能力には遠く及びません。

CUDA (Compute Unified Device Architecture)

NVIDIA社製のGPUで利用される並列コンピューティングプラットフォームです。ディープラーニングの分野で最も広く使われており、PyTorchのGPU高速化のデファクトスタンダードです。

  • コード例:
    import torch
    
    if torch.cuda.is_available():
        device = torch.device("cuda") # CUDAデバイスを指定
        print("CUDAデバイスが利用可能です。")
    else:
        device = torch.device("cpu")
        print("CUDAデバイスは利用できません。CPUを使用します。")
    
    print(f"現在のデバイス: {device}")
    
    x = torch.randn(10, 10, device=device)
    model = YourModel().to(device)
    
    # CUDA特有のメモリ管理 (mps.empty_cache()に相当)
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    
    注意: Mac上でCUDAを使用する場合、eGPU(外部GPU)としてNVIDIA製GPUを接続するか、Windows/Linuxマシンを用意する必要があります。Apple Silicon Macでは直接CUDAは利用できません。
  • 利用場面:
    • 大規模なモデルのトレーニング(特に大規模言語モデルや画像生成モデル)。
    • 本番環境での高速な推論。
    • 最新の研究成果を試す際。
  • 特徴:
    • 圧倒的な高速性: NVIDIA GPUの並列処理能力を最大限に引き出します。
    • 広範なライブラリサポート: ほとんどのディープラーニングライブラリやフレームワークがCUDAに最適化されています。
    • 成熟したエコシステム: 長年の開発とコミュニティのサポートにより、安定性と機能が非常に充実しています。
    • ハードウェア要件: NVIDIA製GPUが必要です。

MLX (Apple Machine Learning eXchange)

Appleが開発・公開している、Apple Siliconに最適化された機械学習フレームワークです。PyTorchとは異なる独立したフレームワークですが、NumPyに似たAPIを提供し、MPSと同じくMetalを利用して高速な計算を実現します。

  • コード例 (MLX):
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    
    # MLXではデバイス管理が暗黙的
    # デフォルトでMPS (Metal) を使用しようとします
    
    # テンソルの作成
    x = mx.random.normal((3, 4))
    print(f"MLX上のテンソルx:\n{x}")
    
    # シンプルなモデルの定義
    class SimpleMLXNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 5)
            self.relu = nn.relu
            self.fc2 = nn.Linear(5, 1)
    
        def __call__(self, x):
            return self.fc2(self.relu(self.fc1(x)))
    
    model = SimpleMLXNet()
    
    # 損失関数とオプティマイザ
    loss_fn = lambda model, X, y: mx.mean((model(X) - y)**2)
    optimizer = optim.SGD(learning_rate=0.01)
    
    # トレーニング関数 (MLXの関数変換を利用)
    @mx.compile
    def train_step(model, X, y, optimizer):
        loss, grads = mx.value_and_grad(model, loss_fn)(model, X, y)
        optimizer.update(model, grads)
        return loss
    
    # ダミーデータ
    X_train = mx.random.normal((100, 10))
    y_train = mx.random.normal((100, 1))
    
    # トレーニングループ
    print("\nMLXトレーニング開始...")
    for epoch in range(100):
        loss = train_step(model, X_train, y_train, optimizer)
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
    print("MLXトレーニング終了。")
    
    注意: MLXはPyTorchとは異なるAPIを持つため、既存のPyTorchコードをMLXに移行するには書き換えが必要です。
  • 利用場面:
    • Apple Silicon環境での機械学習のプロトタイピングや研究開発。
    • PyTorchの複雑さに悩まされず、よりシンプルにGPU計算を行いたい場合。
    • torch.mpsでパフォーマンスが出ない、または未サポートのオペレーターが多い場合に試す価値があります。
  • 特徴:
    • Apple Siliconに特化: MPSと同様に、Apple Siliconのユニファイドメモリと高性能なGPUを最大限に活用します。
    • 軽量でシンプル: Python APIはNumPyとPyTorchからインスパイアされており、使いやすい設計です。
    • 動的グラフ: PyTorchのように動的な計算グラフをサポートします。
    • PyTorchとの比較: 特定のベンチマークではPyTorch MPSよりも高速な場合もあれば、遅い場合もあります。まだ発展途上です。

これらはトレーニング時ではなく、推論(モデルのデプロイ)に焦点を当てた代替手段です。トレーニング済みのPyTorchモデルをこれらの形式に変換し、最適化されたランタイムで実行します。

  • Core ML:

    • 特徴: Appleデバイス(iPhone, iPad, Mac)上で機械学習モデルを効率的に実行するためのApple独自のフレームワークです。モデルをCore ML形式に変換すると、CPUやGPU、Neural Engine(NPU)を自動的に活用して高速な推論が可能です。
    • 利用場面: iOS/macOSアプリケーションへのモデル組み込み。
    • PyTorchからの変換: coremltoolsライブラリを使用してPyTorchモデルをCore ML形式に変換できます。
    • 注意: 変換プロセスはモデルの複雑さやオペレーターのサポート状況によって異なります。
  • ONNX (Open Neural Network Exchange) Runtime:

    • 特徴: 異なるフレームワーク(PyTorch, TensorFlowなど)でトレーニングされたモデルを共通の中間表現であるONNX形式に変換し、様々なハードウェア(CPU, GPU, 専用アクセラレータ)で実行できるランタイムです。プラットフォーム非依存性が高いです。
    • 利用場面: クラウド、エッジデバイス、異なるOS(Windows, Linux, macOS)など、多様な環境でのモデルデプロイ。
    • コード例 (PyTorchモデルをONNXに変換):
      import torch
      import torch.nn as nn
      
      class SimpleNet(nn.Module):
          def __init__(self):
              super(SimpleNet, self).__init__()
              self.fc1 = nn.Linear(10, 5)
              self.relu = nn.ReLU()
              self.fc2 = nn.Linear(5, 1)
      
          def forward(self, x):
              return self.fc2(self.relu(self.fc1(x)))
      
      model = SimpleNet()
      dummy_input = torch.randn(1, 10) # ダミー入力
      
      # モデルをONNX形式でエクスポート
      torch.onnx.export(model,
                        dummy_input,
                        "simple_net.onnx", # 出力ファイル名
                        opset_version=17,  # ONNXオペレーションセットバージョン
                        input_names=['input'],
                        output_names=['output'],
                        dynamic_axes={'input': {0: 'batch_size'},
                                      'output': {0: 'batch_size'}})
      print("モデルを simple_net.onnx にエクスポートしました。")
      
      # ONNX Runtimeでの実行 (例)
      import onnxruntime as ort
      import numpy as np
      
      ort_session = ort.InferenceSession("simple_net.onnx")
      # モデルの入力形式に合わせてNumPy配列を用意
      ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
      ort_outputs = ort_session.run(None, ort_inputs)
      print(f"ONNX Runtimeでの推論結果:\n{ort_outputs[0]}")
      

torch.mpsはApple Siliconユーザーにとって非常に強力な機能ですが、それが常に唯一の選択肢であるとは限りません。

  • デプロイ: モデルを最終的なアプリケーションに組み込む際には、ONNX RuntimeやCore MLのような推論に特化したランタイムが、プラットフォーム間の互換性やパフォーマンスの最適化において優れた選択肢となります。
  • Apple Siliconに特化した最適化: torch.mpsで問題がある場合や、より深いレベルでの最適化を求める場合は、MLXを検討する価値があります。
  • 大規模なトレーニング: NVIDIA GPUを搭載したクラウド環境や専用のマシンがある場合は、CUDAが最高のパフォーマンスを提供します。
  • 開発・実験: Mac上で作業する場合、まずはtorch.mpsを試すのが最も簡単で効率的です。