Efficient Object Synchronization in Distributed PyTorch with torch.distributed.broadcast_object_list()
Purpose
In distributed PyTorch training, where you train a model across multiple processes or machines (nodes), it's often necessary to synchronize parameters or other objects between these processes. torch.distributed.broadcast_object_list() facilitates this synchronization by broadcasting a list of Python objects from a designated source rank (usually rank 0) to all other ranks in the distributed process group.
How it Works
Process Group Initialization
- Before using broadcast_object_list(), you must initialize the distributed process group using either torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh(). This ensures all processes are coordinated and ready to communicate; a minimal sketch is shown below.
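A minimal sketch of that initialization, assuming the script is launched with torchrun (which sets MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE in the environment):

import torch.distributed as dist

# Join the default process group; "gloo" works on CPU, "nccl" is the usual choice for GPUs.
dist.init_process_group(backend="gloo", init_method="env://")
# ... distributed work such as broadcast_object_list() goes here ...
dist.destroy_process_group()

Such a script is typically launched with, for example, torchrun --nproc_per_node=2 script.py.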
Object Preparation
- Create a list of Python objects (object_list) that you want to broadcast.
- Crucially, all objects in this list must be picklable. Pickling allows Python objects to be converted into a byte stream that can be transmitted across processes; a quick check is sketched below.
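Because the broadcast relies on pickling, a quick sanity check is to run every element of object_list through a pickle round trip before broadcasting. The helper below is purely illustrative and not part of torch.distributed:

import pickle
import torch

def assert_picklable(object_list):
    # Illustrative helper: raise a clear error if any element cannot be pickled.
    for i, obj in enumerate(object_list):
        try:
            pickle.dumps(obj)
        except Exception as exc:
            raise TypeError(f"object_list[{i}] ({type(obj).__name__}) is not picklable") from exc

assert_picklable([torch.tensor([1, 2, 3]), "hello"])  # passes
# assert_picklable([lambda x: x])  # would raise: lambdas are not picklable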
Broadcast Operation
- Call torch.distributed.broadcast_object_list(object_list, src=0); the full argument list is sketched below.
- Parameters:
  - object_list: The list of objects to broadcast. Every rank must pass a list of the same length; on non-source ranks the placeholder elements are replaced in place with the received objects.
  - src (optional): The rank of the source process that holds the original list. By default it is 0 (rank 0 broadcasts).
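For reference, recent PyTorch releases expose a couple of additional keyword arguments on this call (check the documentation of your installed version); a sketch with the defaults spelled out:

import torch.distributed as dist

# ... (process group initialized and object_list prepared as described above)
# group=None targets the default (world) process group; device, if given, is where the
# temporary tensors holding the pickled bytes are placed before the broadcast.
dist.broadcast_object_list(object_list, src=0, group=None, device=None)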
Synchronization
- Each process in the group participates in a communication step to receive the broadcasted objects.
- After this step, all processes will have a copy of the original object_list from rank 0.
Example
import torch
import torch.distributed as dist

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), "hello"]  # Picklable objects
else:
    objects = [None, None]  # Placeholder list of the same length on receiving ranks
dist.broadcast_object_list(objects, src=0)
# Now all processes have `objects` containing the broadcast tensor and string
print(objects)
Important Considerations
- Custom class objects might require additional pickling steps (such as defining __getstate__) to ensure their attributes are preserved during transmission; see the manual-pickling example below.
- broadcast_object_list() only works within a distributed process group. It has no effect in a non-distributed environment.
Alternatives for Non-Picklable Objects
If you have non-picklable objects, consider these alternatives:
- Manual Communication: Implement custom communication logic using torch.distributed.send() and torch.distributed.recv() if more control over the communication pattern is needed (see the send/recv example below).
- Serialization: Use libraries like dill for more flexible serialization options (see the dill example below).
Broadcasting Tensors and Other Data
import torch
import torch.distributed as dist

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), "hello", [4, 5, 6]]  # Tensors, strings, and lists are picklable
else:
    objects = [None, None, None]
dist.broadcast_object_list(objects, src=0)
print(dist.get_rank(), objects)
This code broadcasts a list containing a tensor, a string, and another list. All processes will have the same list after the broadcast.
Broadcasting with Custom Class (Manual Pickling)
import torch.distributed as dist
import pickle

class MyClass:
    def __init__(self, value):
        self.value = value

    def __getstate__(self):
        # Define custom pickling logic (optional)
        return {'value': self.value}

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    object_list = [pickle.dumps(MyClass(10))]  # Manually pickle the custom class on rank 0
else:
    object_list = [None]  # Placeholder; the broadcast fills it in place
dist.broadcast_object_list(object_list, src=0)
obj = pickle.loads(object_list[0])  # Unpickle the received bytes on every rank
print(dist.get_rank(), obj.value)
This code demonstrates broadcasting a custom class object, MyClass. It defines an (optional) __getstate__ method to control the pickling behavior and manually pickles the object on rank 0 before broadcasting; every rank then unpickles the received bytes to reconstruct the object. Note that broadcast_object_list() already pickles each element internally, so manual pickling like this is mainly useful when you want explicit control over serialization.
Sending Different Data to Different Ranks (Conditional Broadcast)
import torch
import torch.distributed as dist

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), None]  # Rank 0's list is the one that gets broadcast
else:
    objects = [None, torch.tensor([4, 5, 6])]  # These placeholder values are overwritten by the broadcast
dist.broadcast_object_list(objects, src=0)
print(dist.get_rank(), objects)
Be careful with this pattern: broadcast_object_list() overwrites every element of objects on the non-source ranks with rank 0's values, so after the call all ranks hold [tensor([1, 2, 3]), None]. The tensor pre-filled on the other ranks is simply discarded; a broadcast cannot deliver different data to different ranks.
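If the goal really is to deliver a different object to each rank, torch.distributed.scatter_object_list() is the collective intended for that (available in recent PyTorch versions); a minimal sketch for a two-process group:

import torch
import torch.distributed as dist

# ... (distributed process group initialization)
output = [None]  # each rank receives exactly one object into this one-element list
if dist.get_rank() == 0:
    inputs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]  # one entry per rank
else:
    inputs = None  # ignored on non-source ranks
dist.scatter_object_list(output, inputs, src=0)
print(dist.get_rank(), output[0])  # rank 0: tensor([1, 2, 3]); rank 1: tensor([4, 5, 6])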
Serialization for Non-Picklable Objects
import torch.distributed as dist
import dill

class MyClass:
    # ... (custom class definition; minimal stand-in shown so the example runs)
    def __init__(self, value):
        self.value = value

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    object_list = [dill.dumps(MyClass(10))]  # Serialize the object using dill; the resulting bytes are picklable
else:
    object_list = [None]
dist.broadcast_object_list(object_list, src=0)
obj = dill.loads(object_list[0])  # Deserialize on every rank
print(dist.get_rank(), obj.value)
Manual Communication with torch.distributed.send() and torch.distributed.recv()
- If you need more fine-grained control over the communication pattern, or have communication requirements beyond broadcasting, consider using torch.distributed.send() and torch.distributed.recv(). These functions let you send and receive tensors directly between specific ranks:
import torch
import torch.distributed as dist

# ... (distributed process group initialization)
if dist.get_rank() == 0:
    data_to_send = torch.tensor([1, 2, 3])
    for rank in range(1, dist.get_world_size()):
        dist.send(data_to_send, dst=rank)  # point-to-point send to each other rank
else:
    data_to_receive = torch.empty(3, dtype=torch.int64)  # buffer matching the sender's shape and dtype
    dist.recv(data_to_receive, src=0)
    print(dist.get_rank(), data_to_receive)
This code manually sends the tensor from rank 0 to all other ranks using send(); each receiving rank calls recv() into a pre-allocated buffer of matching shape and dtype. For a simple one-to-all transfer like this, the collective dist.broadcast() (or broadcast_object_list() for non-tensor data) is usually more efficient than a loop of point-to-point sends.
- PyTorch also provides a distributed key-value store through torch.distributed (for example, TCPStore). While not directly related to object broadcasting, you can use it to store and retrieve key-value pairs across processes, which can be useful for sharing configuration data or intermediate results. For broadcasting objects, however, the options above are more suitable.
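As a small illustration of that key-value store, here is a sketch using TCPStore; the address, port, and world size are placeholder values, and the rank is assumed to come from an environment variable set by a launcher such as torchrun:

import os
from datetime import timedelta
import torch.distributed as dist

rank = int(os.environ["RANK"])  # assumes the launcher sets RANK
# Rank 0 hosts the store; the other ranks connect to it as clients.
store = dist.TCPStore("127.0.0.1", 29501, world_size=2, is_master=(rank == 0),
                      timeout=timedelta(seconds=30))
if rank == 0:
    store.set("config", "batch_size=32")  # values are stored as bytes
print(rank, store.get("config"))  # get() waits for the key and returns b"batch_size=32"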