Efficient Object Synchronization in Distributed PyTorch with torch.distributed.broadcast_object_list()


Purpose

  • In distributed PyTorch training, where a model is trained across multiple processes or machines (nodes), it is often necessary to synchronize configuration values, metadata, or other Python objects between these processes.
  • torch.distributed.broadcast_object_list() facilitates this synchronization by broadcasting a list of Python objects from a designated source rank (usually rank 0) to all other ranks in the distributed process group.

How it Works

  1. Process Group Initialization

    • Before using broadcast_object_list(), you must initialize the distributed process group using either torch.distributed.init_process_group() or torch.distributed.device_mesh.init_device_mesh(). This ensures all processes are coordinated and ready to communicate (see the initialization sketch after this list).
  2. Object Preparation

    • Create a list of Python objects (object_list) that you want to broadcast.
    • Crucially, all objects in this list must be picklable. Pickling converts Python objects into a byte stream that can be transmitted between processes.
  3. Broadcast Operation

    • Call torch.distributed.broadcast_object_list(object_list, src=0).
    • Parameters
      • object_list: The list of objects to broadcast. On the source rank it holds the objects to send; on every other rank it must contain the same number of elements (e.g. None placeholders), which are overwritten in place with the received objects.
      • src (optional): The rank of the source process that holds the original list. By default, it's 0 (rank 0 broadcasts).
  4. Synchronization

    • Each process in the group participates in a communication step to receive the broadcasted objects.
    • After this step, all processes will have a copy of the original object_list from rank 0.
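
The initialization step in point 1 might look like the following minimal sketch. It assumes the script is launched with torchrun (which sets the RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables) and uses the gloo backend; substitute "nccl" for GPU training.

import torch.distributed as dist

# Minimal initialization sketch, assuming launch via torchrun,
# which sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.
dist.init_process_group(backend="gloo")  # use "nccl" for GPU training

rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"Initialized rank {rank} of {world_size}")

# ... broadcast_object_list() and other collectives go here ...

dist.destroy_process_group()

With the launcher handling rank assignment, init_process_group() can rely on its default env:// initialization method.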

Example

import torch
import torch.distributed as dist

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), "hello"]  # Picklable objects
else:
    objects = [None, None]  # Placeholders, same length as rank 0's list; overwritten in place

dist.broadcast_object_list(objects, src=0)

# Now, all processes will have `objects` containing the broadcasted tensors and string
print(objects)

Important Considerations

  • Custom class objects might require extra care: the class must be importable on every rank, and classes with non-picklable attributes may need __getstate__/__setstate__ methods so their state survives transmission.
  • broadcast_object_list() only works within an initialized distributed process group; calling it before init_process_group() raises an error. If the same code must also run as a plain single process, guard the call as in the sketch below.
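
A minimal guard sketch, assuming the surrounding code may run either under a distributed launcher or as a plain single process (the helper name broadcast_config is hypothetical):

import torch.distributed as dist

def broadcast_config(config_list, src=0):
    # Hypothetical helper: broadcast only when a process group exists,
    # otherwise leave the local list untouched (single-process run).
    if dist.is_available() and dist.is_initialized():
        dist.broadcast_object_list(config_list, src=src)
    return config_list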

Alternatives for Non-Picklable Objects

If you have non-picklable objects, consider these alternatives:

  • Manual Communication
    Implement custom communication logic using torch.distributed.send() and torch.distributed.recv() if more control over the communication pattern is needed.
  • Serialization
    Use libraries like dill for more flexible serialization options.


Broadcasting Tensors and Other Data

import torch
import torch.distributed as dist

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), "hello", [4, 5, 6]]  # Tensors and lists are picklable
else:
    objects = [None, None, None]  # Placeholders, overwritten in place with rank 0's objects

dist.broadcast_object_list(objects, src=0)

print(dist.get_rank(), objects)

This code broadcasts a list containing a tensor, a string, and another list. All processes will have the same list after the broadcast.
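
One practical detail: object collectives pickle the list into a byte tensor, and with the NCCL backend that tensor must live on the correct GPU. The sketch below assumes an NCCL process group, a torchrun launch (for the LOCAL_RANK environment variable), and a PyTorch version recent enough to accept the optional device argument on broadcast_object_list().

import os

import torch
import torch.distributed as dist

# Assumes initialization with backend="nccl" and launch via torchrun,
# which provides the LOCAL_RANK environment variable.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # so NCCL knows which GPU this rank owns

if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), "hello", [4, 5, 6]]
else:
    objects = [None, None, None]

# The optional device argument (available in recent PyTorch releases) controls
# where the pickled byte tensor is staged before the broadcast.
dist.broadcast_object_list(objects, src=0, device=torch.device("cuda", local_rank))
print(dist.get_rank(), objects)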

Broadcasting with Custom Class (Manual Pickling)

import torch.distributed as dist
import pickle

class MyClass:
    def __init__(self, value):
        self.value = value

    def __getstate__(self):
        # Define custom pickling logic (optional)
        return {'value': self.value}

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    obj = MyClass(10)
    object_list = [pickle.dumps(obj)]  # Manually pickle the custom class into bytes
else:
    object_list = [None]  # Placeholder; overwritten in place by the broadcast

dist.broadcast_object_list(object_list, src=0)

obj = pickle.loads(object_list[0])  # Unpickle on every rank (including rank 0)
print(dist.get_rank(), obj.value)

This code demonstrates broadcasting a custom class object MyClass. It defines an optional __getstate__ method to control the pickling behavior, manually pickles the object into bytes on rank 0, broadcasts that byte string, and unpickles it from object_list[0] on every rank. Note that broadcast_object_list() updates the list in place, so the result must be read back from the list that was passed in. Manual pickling is rarely required, though: broadcast_object_list() already pickles its elements internally, so a picklable custom object can be broadcast directly, as shown below.
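
A minimal sketch of that direct approach, reusing the MyClass definition above:

import torch.distributed as dist

# ... (distributed process group initialization and the MyClass definition above)

if dist.get_rank() == 0:
    objects = [MyClass(10)]  # broadcast_object_list() pickles this internally
else:
    objects = [None]         # Placeholder; overwritten in place

dist.broadcast_object_list(objects, src=0)
print(dist.get_rank(), objects[0].value)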

Non-Source Values Are Placeholders (Not a Conditional Broadcast)

import torch
import torch.distributed as dist

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    objects = [torch.tensor([1, 2, 3]), None]  # Source list: this is what every rank ends up with
else:
    objects = [None, torch.tensor([4, 5, 6])]  # Placeholders only; replaced by rank 0's list

dist.broadcast_object_list(objects, src=0)

print(dist.get_rank(), objects)

This example highlights an important point: broadcast_object_list() does not send different data to different ranks. Whatever the non-source ranks place in objects is treated purely as a placeholder and is overwritten, so after the broadcast every rank holds rank 0's list, [tensor([1, 2, 3]), None]. To give each rank its own data, use torch.distributed.scatter_object_list() or point-to-point send()/recv() instead (see the sketch below).
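
If each rank genuinely needs different data, torch.distributed.scatter_object_list() (available in recent PyTorch releases) is the matching collective. A minimal sketch, with made-up per-rank payloads for illustration:

import torch.distributed as dist

# ... (distributed process group initialization)

world_size = dist.get_world_size()

if dist.get_rank() == 0:
    # One picklable entry per destination rank; entry i is delivered to rank i.
    scatter_input = [{"rank": i, "payload": i * 10} for i in range(world_size)]
else:
    scatter_input = None  # Only the source rank supplies the input list

scatter_output = [None]  # Receives exactly one object on every rank
dist.scatter_object_list(scatter_output, scatter_input, src=0)

print(dist.get_rank(), scatter_output[0])

Unlike the broadcast, each rank ends up with its own entry from rank 0's input list.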



Serialization for Non-Picklable Objects

import torch.distributed as dist
import dill

class MyClass:
    def __init__(self, value):
        self.value = value
    # ... (rest of the custom class definition)

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    obj = MyClass(10)
    object_list = [dill.dumps(obj)]  # Serialize the object using dill
else:
    object_list = [None]  # Placeholder; overwritten in place by the broadcast

dist.broadcast_object_list(object_list, src=0)

obj = dill.loads(object_list[0])  # Deserialize on every rank
print(dist.get_rank(), obj.value)

This mirrors the pickle example, but dill can serialize objects that the standard pickle module cannot (such as lambdas or locally defined classes); the resulting byte string is itself picklable, so broadcast_object_list() can carry it.

Manual Communication with torch.distributed.send() and torch.distributed.recv()

  • If you need more fine-grained control over the communication pattern than a broadcast provides, consider torch.distributed.send() and torch.distributed.recv(). These point-to-point functions send and receive tensors directly between specific ranks:
import torch
import torch.distributed as dist

# ... (distributed process group initialization)

if dist.get_rank() == 0:
    data_to_send = torch.tensor([1, 2, 3])
    for rank in range(1, dist.get_world_size()):
        dist.send(data_to_send, dst=rank)
else:
    # The receiver must pre-allocate a tensor with the expected shape and dtype.
    data_to_receive = torch.empty(3, dtype=torch.int64)
    dist.recv(data_to_receive, src=0)
    print(dist.get_rank(), data_to_receive)

This code manually sends the tensor from rank 0 to all other ranks using send(). Each receiving rank calls recv() with a pre-allocated tensor of matching shape and dtype, which recv() fills in place.

  • PyTorch also provides a distributed key-value store (for example, torch.distributed.TCPStore) through torch.distributed. While not directly related to object broadcasting, you can use it to store and retrieve key-value pairs across processes, which can be useful for sharing configuration data or small intermediate results. For broadcasting objects, however, the options above are more suitable; a small sketch of the store API follows.
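
A minimal sketch of the store API, assuming torch.distributed.TCPStore with an illustrative host, port, and world size; values are stored and returned as bytes.

import os
from datetime import timedelta

import torch.distributed as dist

# Minimal key-value store sketch; host, port, and world size are illustrative assumptions.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

store = dist.TCPStore(
    "127.0.0.1",                # host running the store server
    29500,                      # port (must be free on that host)
    world_size,
    is_master=(rank == 0),      # exactly one process hosts the store
    timeout=timedelta(seconds=30),
)

if rank == 0:
    store.set("config", "batch_size=32")  # values are stored as bytes
print(rank, store.get("config"))          # get() blocks until the key exists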