Understanding FSDP's Optimizer State Flattening: torch.distributed.fsdp.FullyShardedDataParallel.flatten_sharded_optim_state_dict()
Fully Sharded Data Parallel (FSDP) in PyTorch
FSDP is a distributed training technique in PyTorch that lets you train large models across many GPUs, on one or several machines, by sharding (splitting) the model parameters, gradients, and optimizer states across the participating processes. Because each process holds only a fraction of that state, FSDP relieves per-device memory pressure and lets you train models that would not fit on a single device.
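To make the memory savings concrete, here is a back-of-the-envelope sketch. It assumes the Adam optimizer in fp32, which keeps two extra values (exp_avg and exp_avg_sq) per parameter, and it ignores padding, gradients, and the parameters themselves. The helper names are hypothetical and not part of PyTorch.

```python
def adam_state_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for Adam's two per-parameter buffers (exp_avg, exp_avg_sq)."""
    # The scalar "step" counter is ignored; it is negligible next to the buffers.
    return 2 * num_params * bytes_per_value


def per_rank_optimizer_bytes(num_params: int, world_size: int, sharded: bool) -> int:
    """Optimizer-state bytes held by one rank, replicated vs. sharded."""
    total = adam_state_bytes(num_params)
    # Replicated (DDP-style): every rank holds the full state.
    # Sharded (FSDP-style): each rank holds about 1/world_size, rounded up.
    return -(-total // world_size) if sharded else total


if __name__ == "__main__":
    n = 1_000_000_000  # hypothetical 1B-parameter model, 8 ranks
    print(per_rank_optimizer_bytes(n, 8, sharded=False) / 1e9, "GB per rank (replicated)")
    print(per_rank_optimizer_bytes(n, 8, sharded=True) / 1e9, "GB per rank (sharded)")
```

Under these assumptions, a 1B-parameter model needs about 8 GB of Adam state on every rank when replicated, but about 1 GB per rank when sharded across 8 ranks.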
torch.distributed.fsdp.FullyShardedDataParallel.flatten_sharded_optim_state_dict()
Function
This static method is part of FSDP's checkpointing workflow, specifically the loading side. It takes an optimizer state dictionary produced by FSDP.sharded_optim_state_dict(), in which the state is keyed by the original (unflattened) parameter names and stored as ShardedTensors, and converts it back into the form the local optimizer expects: state keyed by FSDP's flattened parameters and restricted to the current rank's shard. The result can be passed directly to optimizer.load_state_dict() when resuming training from a sharded checkpoint.
Breakdown of the Function
- Input: The function takes three arguments: sharded_optim_state_dict, the dictionary returned by FSDP.sharded_optim_state_dict() (typically loaded back from a checkpoint file); model, the FSDP-wrapped model; and optim, the optimizer for that model's parameters.
- Gathering: In the sharded dictionary, each non-scalar state tensor is stored as a ShardedTensor whose local shard covers only part of it, so the function issues all-gather calls on each rank to collect the pieces. Every rank must therefore call it, not just one.
- Flattening: The gathered per-parameter state is then remapped from the original parameter names to FSDP's flattened parameters (the flat buffers that the wrapped optimizer actually updates), and only the portion belonging to the current rank is kept.
- Output: The function returns this rank's optimizer state dictionary, which you pass to optim.load_state_dict() to restore the optimizer, for example when resuming training.
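The following toy sketch illustrates what "flattened and restricted to this rank" means. It is not FSDP's implementation: plain Python lists stand in for tensors, a single list of values per parameter stands in for an optimizer state entry such as exp_avg, and flatten_and_shard is a hypothetical helper. It assumes a scheme of concatenating parameters in a fixed order into one flat buffer, padding it to a multiple of the world size, and giving each rank one equal, contiguous chunk.

```python
from typing import Dict, List


def flatten_and_shard(
    per_param_state: Dict[str, List[float]],
    param_order: List[str],
    world_size: int,
    rank: int,
) -> List[float]:
    """Toy model of flattening: concatenate, pad, keep this rank's chunk."""
    # 1. Concatenate per-parameter values in a fixed order into one flat buffer.
    flat = [v for name in param_order for v in per_param_state[name]]
    # 2. Pad with zeros so the buffer splits evenly across ranks.
    padded_len = -(-len(flat) // world_size) * world_size
    flat += [0.0] * (padded_len - len(flat))
    # 3. Keep only this rank's contiguous chunk.
    chunk = padded_len // world_size
    return flat[rank * chunk:(rank + 1) * chunk]


if __name__ == "__main__":
    state = {"w": [1.0, 2.0, 3.0], "b": [4.0]}  # hypothetical exp_avg values
    for r in range(3):
        print(r, flatten_and_shard(state, ["w", "b"], world_size=3, rank=r))
```

With a 3-element and a 1-element parameter and a world size of 3, the flat buffer [1.0, 2.0, 3.0, 4.0] is padded to six elements, so ranks 0, 1, and 2 receive [1.0, 2.0], [3.0, 4.0], and [0.0, 0.0]. Note that a rank's chunk can straddle parameter boundaries, which is why the optimizer's local state cannot simply be keyed by the original parameter names.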
Importance of Flattening
- Model Serialization: The sharded optimizer state dictionary is keyed by the original parameter names rather than by FSDP's internal flat parameters, which makes it a natural format to save. The wrapped optimizer, however, only knows about the flattened parameters, so a saved dictionary cannot be handed to optimizer.load_state_dict() unchanged; flattening performs that translation.
- Checkpointing: Checkpoints capture the current state of the model and optimizer so that training can resume from that point if necessary. With sharded checkpoints, each rank saves its part of the state, and on resume flatten_sharded_optim_state_dict() turns the loaded state back into the per-rank form the optimizer needs.
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# ... (process-group setup, FSDP-wrapped model, optimizer, training loop setup)

def save_checkpoint(model, optimizer, prefix):
    # Run on every rank: collecting optimizer state communicates across ranks.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        model_state_dict = model.state_dict()
        # Optimizer state keyed by original parameter names, held as ShardedTensors
        sharded_optim_state_dict = FSDP.sharded_optim_state_dict(model, optimizer)
    # Each rank writes its own shard of the checkpoint
    torch.save({
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': sharded_optim_state_dict,
        # ... other training state (optional)
    }, f"{prefix}_rank{dist.get_rank()}.pt")

def load_checkpoint(model, optimizer, prefix):
    # weights_only=False: the file holds ShardedTensors, not only plain tensors
    checkpoint = torch.load(f"{prefix}_rank{dist.get_rank()}.pt", weights_only=False)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        model.load_state_dict(checkpoint['model_state_dict'])
    # Re-flatten to FSDP's flat parameters, keeping only this rank's shard
    flattened_optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
        checkpoint['optimizer_state_dict'], model, optimizer)
    optimizer.load_state_dict(flattened_optim_state_dict)

# ... (training loop)
# Save checkpoint periodically: call on every rank, not only rank 0
save_checkpoint(model, optimizer, "checkpoint")
```
- The save_checkpoint function runs on every rank. Inside an FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT) context, it collects the model's sharded state dictionary and calls FSDP.sharded_optim_state_dict(model, optimizer) to get the optimizer state keyed by the original parameter names. Each rank then writes its own file with torch.save.
- The load_checkpoint function reads the current rank's file, loads the model state under the same state-dict type, and calls FSDP.flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optimizer) to convert the optimizer state back to FSDP's flattened, per-rank layout.
- Finally, it passes the flattened state to optimizer.load_state_dict() so training can resume.
- Call both functions on every rank, not just rank 0: the optimizer-state APIs communicate across ranks, so calling them from a single rank can hang the job, and each rank needs its own shard of the checkpoint anyway.
- This example assumes a single optimizer. If you have multiple optimizers, save and flatten the state dictionary for each one separately.
Manual Gathering
- If you have a strong understanding of FSDP internals and distributed communication, you could gather the sharded optimizer state yourself with collective primitives such as torch.distributed.all_gather, concatenate the shards, and split the result back into per-parameter state. This requires low-level bookkeeping (parameter order, shard sizes, padding) and is generally not recommended because it is complex and error-prone.
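The bookkeeping involved in manual gathering can be sketched as the inverse of flattening. This is a single-process illustration under stated assumptions: gathered_shards stands in for the list that a collective such as torch.distributed.all_gather_object would return (one entry per rank, in rank order), the shards were produced by concatenating parameters in a fixed order, padding to a multiple of the world size, and splitting evenly, and unflatten_gathered is a hypothetical helper rather than a PyTorch API.

```python
from typing import Dict, List


def unflatten_gathered(
    gathered_shards: List[List[float]],
    param_order: List[str],
    numels: Dict[str, int],
) -> Dict[str, List[float]]:
    """Rebuild per-parameter state from shards gathered from every rank."""
    # Shards arrive in rank order, so concatenating them restores the
    # padded flat buffer.
    flat = [v for shard in gathered_shards for v in shard]
    # Drop the trailing padding that made the buffer divide evenly.
    flat = flat[:sum(numels[name] for name in param_order)]
    # Split the flat buffer back into one entry per original parameter.
    out: Dict[str, List[float]] = {}
    offset = 0
    for name in param_order:
        out[name] = flat[offset:offset + numels[name]]
        offset += numels[name]
    return out


if __name__ == "__main__":
    shards = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # as if gathered from 3 ranks
    print(unflatten_gathered(shards, ["w", "b"], {"w": 3, "b": 1}))
```

Every piece of metadata used here (parameter order, element counts, padding) must match what FSDP used when it sharded the state; getting any of it wrong silently misassigns optimizer state, which is the main reason the built-in APIs are preferred.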
FSDP Checkpointing Utilities (For Newer PyTorch)
- Newer PyTorch releases deprecate flatten_sharded_optim_state_dict() in favor of FSDP.optim_state_dict() for saving and FSDP.optim_state_dict_to_load() for loading, with the state-dict type (full or sharded) selected through FSDP.state_dict_type().
- If you use Hugging Face's accelerate library, it provides higher-level abstractions that handle FSDP-specific operations, including state saving. Its Accelerator.save_state() and Accelerator.load_state() methods save and load the complete training state, including the model and the FSDP-managed optimizer state, and take care of gathering and flattening the optimizer state internally.
Alternative Distributed Training Techniques
- If FSDP's sharding doesn't suit your needs, you can explore other distributed training techniques in PyTorch, which manage optimizer state differently. Some alternatives include:
- Distributed Data Parallel (DDP): Each process holds a full replica of the model and its optimizer state, which can be memory-intensive for large models. In exchange, optimizer state management is simpler: the state is not sharded, so a standard optimizer.state_dict() from any one rank describes the full state.
- Gradient Accumulation: Not a distributed strategy on its own, but often combined with one. Gradients from several micro-batches are accumulated before a single optimizer update, which lowers activation memory per step while preserving a larger effective batch size. It does not reduce parameter or optimizer-state memory, and it requires a slightly different training-loop structure.
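Gradient accumulation can be illustrated without any distributed setup. The sketch below uses a hypothetical one-parameter linear model with a mean-squared-error loss, written in plain Python, and checks that averaging gradients over equally sized micro-batches and stepping once gives the same update as one step on the full batch. In PyTorch, the same pattern typically means calling loss.backward() on each scaled micro-batch loss (gradients add into .grad) and calling optimizer.step() and optimizer.zero_grad() only once every few micro-batches.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for the toy model y ≈ w * x


def grad_mse(w: float, batch: Batch) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def step_with_accumulation(w: float, micro_batches: List[Batch], lr: float) -> float:
    """One SGD update using gradients accumulated over micro-batches."""
    accum_steps = len(micro_batches)
    grad = 0.0
    for mb in micro_batches:
        # Scale each micro-batch gradient so the sum equals the mean over all
        # samples (exact when micro-batches have equal size).
        grad += grad_mse(w, mb) / accum_steps
    # The parameter is updated only once, after all micro-batches.
    return w - lr * grad


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
    full_batch = 0.0 - 0.01 * grad_mse(0.0, data)
    accumulated = step_with_accumulation(0.0, [data[:2], data[2:]], lr=0.01)
    print(full_batch, accumulated)
```

The two printed values match, which is the property that lets gradient accumulation trade extra forward/backward passes for lower peak activation memory.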
Choosing the Right Approach
The best approach depends on your specific context:
- Distributed Training Technique: If FSDP doesn't suit your needs, explore alternative distributed training techniques and their optimizer state management strategies.
- Control: If you need more control over the distributed communication used to gather optimizer state, manual gathering is an option, at the cost of higher complexity.
- Complexity: If simplicity is your priority, consider the FSDP checkpointing utilities from accelerate (if applicable).