【初心者向け】PyTorch Tensor Parallelismで分散学習を始めるための第一歩!「torch.distributed.tensor.parallel.parallelize_module()」の使い方


torch.distributed.tensor.parallel.parallelize_module() は、PyTorchにおける Tensor Parallelism 機能の一部であり、大規模なモデルの効率的な訓練を可能にする強力なツールです。この関数は、モデルのモジュールまたはサブモジュールを、複数のデバイス間で分散させるために使用されます。これにより、メモリ使用量を削減し、モデルの訓練にかかる時間を短縮することができます。

詳細解説

parallelize_module() 関数は、以下の引数を取ります。

  • parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]])
    モジュールの分散方法を指定するプラン
  • device_mesh (DeviceMesh)
    デバイス間のメッシュトポロジーを記述するオブジェクト
  • module (nn.Module)
    分散させる対象となるモデルのモジュールまたはサブモジュール

parallelize_plan 引数は、以下のいずれかになります。

  • 辞書: 各モジュールに対する分散方法を個別に指定します。
  • ParallelStyle オブジェクト: 分散方法を統一的に指定します。

import torch
import torch.distributed.tensor.parallel as tp

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# デバイスメッシュを定義
device_mesh = tp.MeshSpec(devices=[torch.device('cuda:0'), torch.device('cuda:1')])

# 分散プランを定義
parallelize_plan = tp.RowwiseParallel()

# モデルを分散させる
tp.parallelize_module(model, device_mesh, parallelize_plan)

上記の例では、model モデルを 2 つの GPU に分散させています。RowwiseParallel 分散プランは、各行を個別の GPU に割り当てる方法を指定します。

  • Tensor Parallelism は、まだ発展途上の機能であり、今後変更される可能性があります。
  • Tensor Parallelism を使用するには、事前に torch.distributed モジュールを初期化しておく必要があります。
  • Tensor Parallelism は、大規模なモデルの訓練に特化した機能であり、小規模なモデルには適していない場合があります。


例 1: 単一のパラレルスタイルでモデル全体を分散させる

import torch
import torch.distributed.tensor.parallel as tp

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# デバイスメッシュを定義
device_mesh = tp.MeshSpec(devices=[torch.device('cuda:0'), torch.device('cuda:1')])

# 分散プランを定義
parallelize_plan = tp.RowwiseParallel()

# モデルを分散させる
tp_model = tp.parallelize_module(model, device_mesh, parallelize_plan)

# データを分散させる
input = torch.randn(10, 100)
tp_input = tp.scatter_kwargs(input, device_mesh)

# モデルを実行
output = tp_model(tp_input)

# 出力を集約
output = tp.gather(output)

この例では、RowwiseParallel 分散プランを使用して、モデル全体を 2 つの GPU に分散させています。すべてのモジュールは、行方向に分割されます。

例 2: 異なるモジュールに異なるパラレルスタイルを適用する

import torch
import torch.distributed.tensor.parallel as tp

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# デバイスメッシュを定義
device_mesh = tp.MeshSpec(devices=[torch.device('cuda:0'), torch.device('cuda:1')])

# 分散プランを定義
parallelize_plan = {
    "linear1": tp.RowwiseParallel(),
    "relu": tp.NoParallel(),
    "linear2": tp.ColumnwiseParallel()
}

# モデルを分散させる
tp_model = tp.parallelize_module(model, device_mesh, parallelize_plan)

# データを分散させる
input = torch.randn(10, 100)
tp_input = tp.scatter_kwargs(input, device_mesh)

# モデルを実行
output = tp_model(tp_input)

# 出力を集約
output = tp.gather(output)

この例では、linear1 モジュールは行方向に分割し、relu モジュールは分散させず、linear2 モジュールは列方向に分割しています。

例 3: カスタム DDP モジュールを使用する

import torch
import torch.distributed.tensor.parallel as tp
import my_custom_ddp_module

# モデルを定義
model = nn.Sequential(
    my_custom_ddp_module.MyCustomDDPModule(),
    nn.Linear(1000, 10)
)

# デバイスメッシュを定義
device_mesh = tp.MeshSpec(devices=[torch.device('cuda:0'), torch.device('cuda:1')])

# 分散プランを定義
parallelize_plan = {
    "my_custom_ddp_module": tp.NoParallel(),
    "linear2": tp.ColumnwiseParallel()
}

# モデルを分散させる
tp_model = tp.parallelize_module(model, device_mesh, parallelize_plan)

# データを分散させる
input = torch.randn(10, 100)
tp_input = tp.scatter_kwargs(input, device_mesh)

# モデルを実行
output = tp_model(tp_input)

# 出力を集約
output = tp.gather(output)

この例では、my_custom_ddp_module モジュールは分散させず、linear2 モジュールは列方向に分割しています。my_custom_ddp_module は、torch.distributed.nn モジュールを使用してカスタム DDP ロジックを実装したモジュールです。



データ並列化 (Data Parallelism)

データ並列化は、最も基本的な並列化手法の一つであり、モデルの入力と出力を複数の GPU に分割して処理します。これは、比較的単純で使いやすい方法ですが、メモリ使用量と通信コストが高くなる可能性があります。

import torch
import torch.nn.parallel as nn

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# データを分割
input = torch.randn(10, 100)
device_ids = [0, 1]
input_split = nn.DataParallel(model, device_ids=device_ids)(input)

モデル並列化 (Model Parallelism)

モデル並列化は、モデルを複数のサブモジュールに分割し、各サブモジュールを異なる GPU で実行する方法です。これは、データ並列化よりもメモリ使用量と通信コストを低減できますが、実装がより複雑になります。

import torch
import torch.distributed as dist
import torch.nn.parallel as nn

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# デバイスを初期化
dist.init_process_group(backend='nccl')

# モデルを分割
device_ids = [0, 1]
model_parallel = nn.DistributedDataParallel(model, device_ids=device_ids)

# データを分割
input = torch.randn(10, 100)
input_split = model_parallel(input)

ハイブリッド並列化 (Hybrid Parallelism)

ハイブリッド並列化は、データ並列化とモデル並列化を組み合わせた手法です。これは、両方の長所を活かし、メモリ使用量、通信コスト、および訓練速度のバランスを取ることができます。

import torch
import torch.distributed as dist
import torch.nn.parallel as nn

# モデルを定義
model = nn.Sequential(
    nn.Linear(100, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10)
)

# デバイスを初期化
dist.init_process_group(backend='nccl')

# モデルを分割
device_ids = [0, 1]
model_parallel = nn.DistributedDataParallel(model, device_ids=device_ids)

# データを分割
input = torch.randn(10, 100)
input_split = model_parallel(input)

Tensor Parallelism以外にも、大規模なモデルの訓練を効率化するためのライブラリがいくつか存在します。

最適な方法の選択

最適な方法は、モデルのサイズ、ハードウェア、および訓練要件によって異なります。一般的には、以下の点を考慮する必要があります。

  • 訓練要件
    訓練速度、メモリ使用量、通信コストのいずれを優先するかによって、最適な方法が異なります。
  • ハードウェア
    利用可能な GPU の数とメモリ容量によって、使用できる並列化手法が制限されます。
  • モデルのサイズ
    モデルが大きければ大きいほど、より高度な並列化手法が必要になります。