Understanding FSDP's Optimizer State Flattening: torch.distributed.fsdp.FullyShardedDataParallel.flatten_sharded_optim_state_dict()


Fully Sharded Data Parallel (FSDP) in PyTorch

FSDP is a data-parallel training technique in PyTorch for training large models across multiple GPUs, which may span several machines. It shards the model's parameters, gradients, and optimizer states across the participating ranks, so each rank holds only a fraction of them. This lowers per-device memory use and lets you train models that would not otherwise fit on a single device.
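To make the memory savings concrete, here is a back-of-the-envelope sketch. It assumes plain fp32 training with Adam, which needs 4 bytes each for the parameter, its gradient, and Adam's two moment buffers (16 bytes per parameter). It ignores activations, padding, and communication buffers, and it models full sharding as an even split of that total across ranks. The 7B-parameter model is purely hypothetical.

```python
# Back-of-the-envelope memory for "model states" only (no activations).
# Assumption: plain fp32 training with Adam, i.e. 4 bytes for each of
# parameter, gradient, exp_avg and exp_avg_sq -> 16 bytes per parameter.
BYTES_PER_PARAM = 4 * 4
GIB = 1024 ** 3


def replicated_bytes(num_params: int) -> int:
    """Per-rank bytes when every rank keeps a full copy of all states."""
    return num_params * BYTES_PER_PARAM


def fully_sharded_bytes(num_params: int, world_size: int) -> int:
    """Per-rank bytes when all states are split evenly across world_size ranks
    (rounded up to a whole byte)."""
    total = num_params * BYTES_PER_PARAM
    return (total + world_size - 1) // world_size


n = 7_000_000_000  # a hypothetical 7B-parameter model
print(f"replicated:     {replicated_bytes(n) / GIB:.1f} GiB per rank")
print(f"sharded over 8: {fully_sharded_bytes(n, 8) / GIB:.1f} GiB per rank")
```

For this hypothetical model, the script reports about 104.3 GiB per rank when states are replicated and about 13.0 GiB per rank when they are sharded over 8 ranks. Mixed precision or a different optimizer would change the per-parameter byte count, but the 1/world_size scaling stays the same.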

torch.distributed.fsdp.FullyShardedDataParallel.flatten_sharded_optim_state_dict() Function

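This function is used on the loading side of sharded optimizer checkpointing. Its input is a sharded optimizer state dict, as returned by FSDP.sharded_optim_state_dict(). In that dict, state is keyed by the model's original (unflattened) parameter names, and the state tensors are stored as ShardedTensors. The function remaps this state onto the flattened parameters that FSDP actually optimizes and keeps only the current rank's portion. You can pass the result directly to optimizer.load_state_dict() when resuming training. Recent PyTorch releases also provide FSDP.optim_state_dict_to_load(), which covers the same use case and is the generally recommended API going forward.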

Breakdown of the Function

  1. Input
    The function takes three arguments. sharded_optim_state_dict is the optimizer state dict returned by FSDP.sharded_optim_state_dict(), keyed by unflattened parameter names and holding state tensors as ShardedTensors. model is the FSDP-wrapped module. optim is the optimizer over that model's parameters.
  2. Gathering
    Each rank issues all-gather calls to collect the ShardedTensor pieces it needs. These are collective operations, so every rank in the process group must call the function, not just rank 0.
  3. Flattening
    The gathered per-parameter states (for example, Adam's exp_avg and exp_avg_sq) are remapped from the original-parameter layout to the layout FSDP's optimizer actually sees. With FSDP's default flat parameters, the states of the original parameters that make up one flat parameter are flattened and concatenated, and only the slice owned by the current rank is kept.
  4. Output
    The function returns an optimizer state dict keyed by the flattened parameters and restricted to this rank's shard. It is meant to be passed to optim.load_state_dict() when resuming training, not saved as a consolidated checkpoint.
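What "flattened" means here can be illustrated with a toy model in plain Python. Lists stand in for tensors, and this is not FSDP's actual code. The sketch assumes FSDP's default layout:

- Several original parameters are grouped into one flat parameter and concatenated in a fixed order.
- The result is zero-padded so it divides evenly by the world size.
- Each rank owns one contiguous chunk.

The same layout applies to each optimizer state, such as Adam's exp_avg. Details such as alignment padding and use_orig_params are ignored.

```python
from typing import Dict, List


def flatten_and_shard(
    per_param_state: Dict[str, List[float]],
    param_order: List[str],
    world_size: int,
    rank: int,
) -> List[float]:
    """Toy model of one optimizer state (e.g. Adam's exp_avg) laid out like an
    FSDP flat parameter: concatenate the original parameters' values in a fixed
    order, zero-pad to a multiple of world_size, and return the contiguous
    chunk owned by `rank`. Lists stand in for tensors."""
    flat: List[float] = []
    for name in param_order:
        flat.extend(per_param_state[name])
    chunk = (len(flat) + world_size - 1) // world_size  # ceil division
    flat.extend([0.0] * (chunk * world_size - len(flat)))  # zero padding
    return flat[rank * chunk:(rank + 1) * chunk]


exp_avg = {"fc.weight": [0.1, 0.2, 0.3, 0.4], "fc.bias": [0.5]}
order = ["fc.weight", "fc.bias"]
for r in range(2):
    print(r, flatten_and_shard(exp_avg, order, world_size=2, rank=r))
# 0 [0.1, 0.2, 0.3]
# 1 [0.4, 0.5, 0.0]
```

The five values do not split evenly over two ranks, so one zero of padding is appended and rank 1's chunk ends with it.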

Importance of Flattening

  • Model Serialization
    Under FSDP, the optimizer steps over flattened, sharded parameters. The raw optimizer.state_dict() on each rank therefore describes only that rank's shard and depends on how the model was wrapped. FSDP.sharded_optim_state_dict() converts it into a form keyed by the original parameter names. flatten_sharded_optim_state_dict() performs the reverse conversion, so the saved state can be loaded back into the optimizer.
  • Checkpointing
    With a sharded checkpoint, each rank saves and loads only its own part of the optimizer state, which avoids the memory cost of consolidating the full state on a single rank. When resuming training, flattening is the step that turns the loaded sharded state back into something optim.load_state_dict() accepts.
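Keying the sharded form by parameter name makes it portable. A regular torch.optim state dict keys its "state" entries by integer positions that follow the order of parameters in "param_groups", so those keys are only meaningful for one particular parameter layout. The helper below is hypothetical and is not a PyTorch API (FSDP does provide FSDP.rekey_optim_state_dict() for real conversions). It shows the name-to-index remapping on plain dictionaries.

```python
from typing import Any, Dict, List


def names_to_indices(
    named_state: Dict[str, Dict[str, Any]],
    param_names_in_optimizer_order: List[str],
) -> Dict[int, Dict[str, Any]]:
    """Hypothetical helper: rekey per-parameter optimizer state from parameter
    names to the integer positions used in a torch.optim state dict's "state"
    entry."""
    index_of = {name: i for i, name in enumerate(param_names_in_optimizer_order)}
    return {index_of[name]: state for name, state in named_state.items()}


named = {
    "fc.weight": {"step": 10, "exp_avg": [0.1, 0.2]},
    "fc.bias": {"step": 10, "exp_avg": [0.3]},
}
# The same named state under two parameter orders yields different index keys.
print(names_to_indices(named, ["fc.weight", "fc.bias"]))
print(names_to_indices(named, ["fc.bias", "fc.weight"]))
```

Index keys change whenever the parameter order changes, while name keys do not. This is why name-keyed state is the safer form to store.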


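import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# ... (process group init, FSDP wrapping of the model, optimizer creation, training loop setup)

def save_checkpoint(model, optimizer, checkpoint_prefix):
    # `model` is the FSDP-wrapped module; every rank must call this (collectives inside)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        model_state_dict = model.state_dict()  # this rank's shard of the weights
    # Optimizer state keyed by original parameter names, stored as ShardedTensors
    sharded_optim_state_dict = FSDP.sharded_optim_state_dict(model, optimizer)

    # Each rank writes only its own shard (the file naming is just an example)
    torch.save({
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': sharded_optim_state_dict,
        # ... other training state (optional)
    }, f"{checkpoint_prefix}_rank{dist.get_rank()}.pth")

def load_checkpoint(model, optimizer, checkpoint_prefix):
    # Also collective; must run with the same world size used for saving
    checkpoint = torch.load(
        f"{checkpoint_prefix}_rank{dist.get_rank()}.pth",
        weights_only=False,  # needed to unpickle ShardedTensors (default is True since PyTorch 2.6)
    )
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        model.load_state_dict(checkpoint['model_state_dict'])

    # Gather and remap the sharded state onto this rank's flattened parameters
    flattened_optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
        checkpoint['optimizer_state_dict'], model, optimizer
    )
    optimizer.load_state_dict(flattened_optim_state_dict)

# ... (training loop)

# Save checkpoint periodically; all ranks must participate
save_checkpoint(model, optimizer, "checkpoint")
  1. save_checkpoint is called on every rank with the FSDP-wrapped model and its optimizer. Inside the SHARDED_STATE_DICT context, model.state_dict() returns this rank's shard of the model weights instead of gathering the full model.
  2. FSDP.sharded_optim_state_dict(model, optimizer) converts the optimizer's internal state, which is keyed by FSDP's flattened parameters, into a sharded state dict keyed by the original parameter names.
  3. Each rank saves its shard of both state dicts to its own file with torch.save.
  4. load_checkpoint, also called on every rank, loads that rank's file and restores the model weights under the same SHARDED_STATE_DICT context.
  5. FSDP.flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optimizer) gathers the pieces it needs and remaps them onto this rank's flattened parameters. optimizer.load_state_dict() then applies the result.
  • Do not guard these calls with if dist.get_rank() == 0. Both functions perform collective communication, so calling them on a single rank would hang the job.
  • If you need a single consolidated file instead, use FSDP.full_optim_state_dict(), which gathers the full optimizer state (on rank 0 by default). That state is loaded back with FSDP.shard_full_optim_state_dict() or FSDP.scatter_full_optim_state_dict(), not with flatten_sharded_optim_state_dict().
  • This example assumes a single optimizer is used. If you have multiple optimizers, save and flatten the state dict for each one separately.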
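A checkpoint written as one file per rank can only be loaded back with the same number of ranks. A small pre-flight check before loading can catch a missing or extra shard before any collective starts. The sketch below is a hypothetical helper that assumes a "<prefix>_rank<N>.pth" naming scheme. Adapt the pattern to however your checkpoint files are actually named.

```python
import re
from typing import Iterable, List


def check_rank_files(prefix: str, world_size: int, filenames: Iterable[str]) -> List[str]:
    """Hypothetical pre-flight check for a one-file-per-rank checkpoint named
    '<prefix>_rank<N>.pth'. Returns a list of problems; an empty list means
    every rank 0..world_size-1 has a file and no extra ranks are present."""
    pattern = re.compile(re.escape(prefix) + r"_rank(\d+)\.pth")
    found = set()
    for name in filenames:
        match = pattern.fullmatch(name)
        if match:
            found.add(int(match.group(1)))
    expected = set(range(world_size))
    problems: List[str] = []
    if expected - found:
        problems.append(f"missing shards for ranks {sorted(expected - found)}")
    if found - expected:
        problems.append(f"unexpected shards for ranks {sorted(found - expected)}")
    return problems


print(check_rank_files("checkpoint", 4, ["checkpoint_rank0.pth", "checkpoint_rank1.pth", "checkpoint_rank2.pth"]))
# ['missing shards for ranks [3]']
```

In a real job you would typically pass dist.get_world_size() and os.listdir() of the checkpoint directory, then abort on every rank if the list is non-empty.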


Alternative Approaches

  1. Manual Gathering (Advanced)

    • If you have a strong understanding of FSDP internals and distributed communication, you could gather the sharded optimizer state yourself. This means using collectives such as torch.distributed.all_gather to collect the state tensors from every rank and then reassembling the per-parameter tensors. You would have to handle FSDP's flat-parameter layout and padding by hand, so this approach is error-prone and generally not recommended.
  2. FSDP Checkpointing Utilities (For Newer PyTorch)

    • Recent PyTorch versions provide FSDP.optim_state_dict() and FSDP.optim_state_dict_to_load(), which follow the state_dict_type configured on the model. The torch.distributed.checkpoint package handles saving and loading sharded checkpoints. If you use Hugging Face's accelerate library, Accelerator.save_state() and Accelerator.load_state() save and restore the full training state, including the model and the FSDP-managed optimizer state, and perform these conversions internally.
  3. Alternative Distributed Training Techniques

    • If FSDP's sharding mechanism doesn't suit your needs, you might explore other distributed training techniques in PyTorch, which manage optimizer state differently. Some alternatives include:
      • Distributed Data Parallel (DDP): Every rank holds a full replica of the model and optimizer state, which can be memory-intensive for large models. Optimizer state management is simpler, however: because the state is not sharded, rank 0's optimizer.state_dict() is already complete and can be saved directly.
      • Gradient Accumulation: This is a memory-saving technique rather than a distributed one, and it can be combined with any of the above. Gradients from several micro-batches are summed before each optimizer step, which reduces activation memory by processing smaller batches at a time. It does not shrink the optimizer state itself. It also requires scaling the loss and deferring optimizer.step() and zero_grad() until the last micro-batch.
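The manual-gathering option is essentially the inverse of FSDP's flatten-and-shard layout. The toy model below works in plain Python on one optimizer state. It assumes each rank holds an equally sized, contiguous chunk of a flat buffer, with padding only at the end and original parameters concatenated in a known order. It ignores any alignment padding FSDP may insert and the use_orig_params mode. In real code, the list of chunks would come from torch.distributed.all_gather on each state tensor, and each per-parameter slice would be reshaped to its parameter's original shape.

```python
from typing import Dict, List


def gather_and_unflatten(
    rank_shards: List[List[float]],
    param_order: List[str],
    numels: Dict[str, int],
) -> Dict[str, List[float]]:
    """Toy model of manual gathering for one optimizer state.

    rank_shards[r] is rank r's contiguous, equally sized chunk of a flat
    buffer (what an all-gather would hand back). The chunks are concatenated,
    trailing padding is dropped, and the buffer is split back into one entry
    per original parameter, in the order the parameters were flattened."""
    flat = [value for shard in rank_shards for value in shard]
    total = sum(numels[name] for name in param_order)
    flat = flat[:total]  # drop padding added to make the buffer divide evenly
    per_param: Dict[str, List[float]] = {}
    offset = 0
    for name in param_order:
        per_param[name] = flat[offset:offset + numels[name]]
        offset += numels[name]
    return per_param


shards = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.0]]  # two ranks; the final 0.0 is padding
print(gather_and_unflatten(shards, ["fc.weight", "fc.bias"], {"fc.weight": 4, "fc.bias": 1}))
# {'fc.weight': [0.1, 0.2, 0.3, 0.4], 'fc.bias': [0.5]}
```

Even in this simplified form, the code has to know the flattening order, each parameter's element count, and where padding lives. That bookkeeping is what flatten_sharded_optim_state_dict() and its successors handle for you.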

Choosing the Right Approach

The best approach depends on your specific context:

  • Distributed Training Technique
    If FSDP doesn't suit your needs, explore alternative distributed training techniques and their optimizer state management strategies.
  • Control
    If you need more control over the distributed communication for gathering optimizer state, manual gathering might be an option (but with higher complexity).
  • Complexity
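    If simplicity is your priority, consider higher-level checkpointing utilities: FSDP.optim_state_dict() together with torch.distributed.checkpoint in recent PyTorch, or Accelerator.save_state() if you already use accelerate.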