Understanding Asynchronous Communication with torch.distributed.irecv() in PyTorch


Concept

  • torch.distributed.irecv() (asynchronous receive) is a function used in distributed PyTorch programs to initiate the receiving of a tensor from another process in a distributed training or inference setup.
  • It is part of the asynchronous point-to-point communication paradigm, allowing processes to overlap communication with computation for potentially faster training.

Key Points

  • Arguments
    • tensor (required): The pre-allocated tensor (with correct shape and dtype) into which the received data is written.
    • src (optional): The rank (integer) of the sending process. If omitted, data may be received from any rank.
    • group, tag (optional): The process group to operate on, and an integer tag for matching the receive with a specific send.
  • Handle-based
    irecv() returns a handle (a Work object) representing the in-flight receive. You can use this handle to check whether the receive has finished (handle.is_completed()) or to block until it does (handle.wait()); the received data is written into the tensor you passed in. See the short sketch after this list.
  • Non-blocking
    Unlike its synchronous counterpart torch.distributed.recv(), irecv() doesn't block the calling process. The process can continue with other computations while waiting for the receive operation to complete.
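
A minimal sketch of this handle pattern (assumes an already-initialized process group and that the code runs on the receiving rank; the tag value is illustrative, and not every backend uses tags):

import torch
import torch.distributed as dist

buf = torch.empty(10)
handle = dist.irecv(buf, src=0, tag=0)  # returns immediately with a Work handle

# ... unrelated computation can run here ...

if not handle.is_completed():  # non-blocking completion check
    handle.wait()              # block until the receive finishes
print(buf)                     # buf now holds the received data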

Usage Example

import torch
import torch.distributed as dist

# ... (distributed process initialization)

# Rank 0 sends a tensor
if dist.get_rank() == 0:
    send_tensor = torch.ones(10)
    send_handle = dist.isend(send_tensor, dst=1)
    send_handle.wait()  # optional here, but guarantees the send completed before send_tensor is reused

# Rank 1 receives the tensor asynchronously
if dist.get_rank() == 1:
    recv_tensor = torch.zeros(10)
    recv_handle = dist.irecv(recv_tensor, src=0)

    # Perform other computations while waiting for receive
    # ...

    # Check whether the receive has finished (non-blocking)
    if not recv_handle.is_completed():
        recv_handle.wait()  # block until the receive completes

    print(recv_tensor)  # recv_tensor now contains the received data
  • You can call recv_handle.is_completed() at any point after initiating the receive with irecv() to check whether it has finished. Calling recv_handle.wait() blocks the process until the receive completes.
  • It's essential to ensure the receiving tensor has the correct shape and data type to match the sending tensor.
  • Device support depends on the backend: with Gloo, point-to-point send/recv works on CPU tensors; with NCCL (in recent PyTorch releases), it works on CUDA tensors; with MPI, GPU tensors require a CUDA-aware MPI build.
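
The snippets above elide process-group setup. One common way to fill that in is a sketch like the following, assuming the Gloo backend and a launch via torchrun (which sets MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE); the helper name is arbitrary:

import torch.distributed as dist

def init_distributed():
    # "env://" reads the rendezvous information from the environment variables set by torchrun
    dist.init_process_group(backend="gloo", init_method="env://")

# Launch with, for example:
#   torchrun --nproc_per_node=2 script.py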


Ring Allreduce (Simplified)

This code implements a simplified (non-chunked) ring allreduce using asynchronous point-to-point communication: each process repeatedly forwards a buffer to its neighbor and accumulates what it receives, so after world_size - 1 steps every rank holds the element-wise sum.

import torch
import torch.distributed as dist

def ring_allreduce_async(tensor):
    """Sum-allreduce `tensor` in place by passing buffers around the ring."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    send_rank = (rank + 1) % world_size  # next neighbor in the ring
    recv_rank = (rank - 1) % world_size  # previous neighbor in the ring

    send_buf = tensor.clone()
    recv_buf = torch.empty_like(tensor)

    for _ in range(world_size - 1):
        # Post the send and the receive asynchronously; using separate buffers
        # avoids overwriting data that is still being sent
        send_handle = dist.isend(send_buf, dst=send_rank)
        recv_handle = dist.irecv(recv_buf, src=recv_rank)

        # Independent computation can overlap with the communication here
        # ...

        send_handle.wait()
        recv_handle.wait()

        # Accumulate the received values and forward them in the next step
        tensor += recv_buf
        send_buf = recv_buf.clone()
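
A possible driver for trying the function locally (a sketch: the spawn-based launcher, world size, and TCP address are illustrative, and ring_allreduce_async is assumed to be defined in the same script):

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

def _worker(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",  # illustrative address/port
        rank=rank,
        world_size=world_size,
    )
    t = torch.full((4,), float(rank))
    ring_allreduce_async(t)
    print(f"rank {rank}: {t}")  # every rank prints the sum 0 + 1 + ... + (world_size - 1)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(4,), nprocs=4)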

Overlapping Communication with Computation

This code demonstrates overlapping communication with computation. Process 0 sends a tensor to process 1 asynchronously and keeps working while the send is in flight; process 1 posts the receive, does some work while the transfer completes, waits for it, and then continues with further work.

import torch
import torch.distributed as dist
import time

def main():
    if dist.get_rank() == 0:
        send_tensor = torch.ones(10)
        send_handle = dist.isend(send_tensor, dst=1)

        # Perform other computations while the send is in flight
        time.sleep(2)  # Simulate some work

        send_handle.wait()  # ensure the send finished before send_tensor is reused or freed

    else:
        recv_tensor = torch.zeros(10)
        recv_handle = dist.irecv(recv_tensor, src=0)

        # Perform calculations before receive
        time.sleep(1)  # Simulate some work

        # If the receive hasn't finished yet, block until it has
        if not recv_handle.is_completed():
            recv_handle.wait()

        # Perform calculations after receive
        time.sleep(3)  # Simulate some work

if __name__ == "__main__":
    # ... (distributed process initialization)
    main()


torch.distributed.recv() (Synchronous Receive)

  • This is the synchronous counterpart of irecv(). It blocks the calling process until the tensor has been received, so the data is guaranteed to be in place before the program continues. That simplifies control flow, but it can stall the process and hurt performance when overlapping communication and computation is desirable.
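
For contrast, a minimal blocking receive (a sketch; assumes an initialized process group and a matching send posted by rank 0):

import torch
import torch.distributed as dist

recv_tensor = torch.zeros(10)
sender_rank = dist.recv(recv_tensor, src=0)  # blocks until the data has arrived
print(sender_rank, recv_tensor)              # recv() returns the rank of the sender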

Remote Procedure Calls (RPC) with torch.distributed.rpc (PyTorch 1.4+)

  • If you need a more flexible communication paradigm than raw tensor sends and receives, consider Remote Procedure Calls (RPC) with torch.distributed.rpc.
  • RPC lets you define a function on one process and execute it on another, passing arguments and getting results back (see the sketch below).
  • While an RPC call is not asynchronous in the same low-level sense as irecv(), rpc_async() returns a future, and you can combine RPC with asynchronous operations inside the remote function for more granular control.
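
A minimal sketch of an asynchronous RPC call (assumes two processes that both run this function, with MASTER_ADDR and MASTER_PORT set in the environment; the worker names and the add_tensors helper are illustrative):

import torch
import torch.distributed.rpc as rpc

def add_tensors(a, b):
    return a + b

def run(rank, world_size):
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Execute add_tensors on worker1; rpc_async returns a future immediately
        fut = rpc.rpc_async("worker1", add_tensors, args=(torch.ones(3), torch.ones(3)))
        # ... other work can overlap here ...
        print(fut.wait())  # tensor([2., 2., 2.])
    rpc.shutdown()  # blocks until all outstanding RPC work is done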

Custom Implementations using Lower-Level Libraries

  • For advanced use cases or specific communication patterns, you can build custom communication logic directly on lower-level libraries such as MPI, NCCL, or Gloo.
  • These libraries offer fine-grained control over communication details, at the cost of requiring deeper knowledge of distributed programming.

Choosing an Approach

  • If asynchronous communication with overlapping computation is crucial, torch.distributed.irecv() remains a good choice.
  • Use torch.distributed.recv() when simplicity and guaranteed completion before proceeding are preferred.
  • Consider torch.distributed.rpc for scenarios requiring more flexible function execution across processes.
  • Opt for custom implementations on lower-level libraries when maximum control and customization are essential and the added complexity is acceptable.