Managing Memory for PyTorch MPS Models: Understanding torch.mps.set_per_process_memory_fraction


Functionality

  • This function manages memory allocation for the MPS (Metal Performance Shaders) backend in PyTorch on Apple devices.
  • It sets an upper limit on the amount of memory a single PyTorch process may allocate on the MPS device.

Arguments

  • fraction (float): The proportion of the device's recommended maximum memory that the process is allowed to consume. It must be between 0 and 2 (inclusive).
    • 0: No memory may be allocated on the MPS device, effectively disabling it for the process.
    • 1: The limit equals the recommended maximum device memory.
    • Values between 0 and 1: The limit is the corresponding fraction of the recommended maximum.
    • Values between 1 and 2: Accepted, but allocating beyond the recommended limit can cause system memory pressure, swapping, or instability.

How it Works

  1. Retrieving Recommended Memory
    The function queries the Metal API for the device's recommendedMaxWorkingSetSize, the largest amount of memory Metal recommends an application keep resident on the device for good performance.
  2. Setting Memory Limit
    The provided fraction is multiplied by the recommendedMaxWorkingSetSize to calculate the allowed memory allocation for the PyTorch process.
  3. Enforcing Limit
    PyTorch's MPS memory allocator ensures that memory usage doesn't exceed this limit. If a process attempts to allocate more memory than allowed, an out-of-memory error is raised.
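
As a minimal sketch of how the limit is derived: recent PyTorch releases (assumed here to be 2.3 or newer) expose the recommended working set size via torch.mps.recommended_max_memory(), so the byte limit implied by a given fraction can be estimated directly. The 0.75 value is just an example.

import torch.mps

fraction = 0.75

# Recommended maximum working set size reported by Metal (assumes PyTorch 2.3+)
recommended_bytes = torch.mps.recommended_max_memory()
limit_bytes = int(fraction * recommended_bytes)
print(f"Recommended max working set: {recommended_bytes} bytes")
print(f"Approximate per-process limit at fraction={fraction}: {limit_bytes} bytes")

# Apply the limit
torch.mps.set_per_process_memory_fraction(fraction)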

Benefits

  • Improved Stability
    By setting appropriate memory limits, you can avoid crashes or performance degradation due to excessive memory usage on the MPS device.
  • Memory Management
    This function helps prevent memory exhaustion, especially when running multiple PyTorch processes or dealing with large models/datasets.

Example

import torch.mps

# Limit this process to 75% of the recommended MPS device memory
torch.mps.set_per_process_memory_fraction(0.75)

Cautions

  • Setting the fraction too high can lead to memory exhaustion if other processes or applications on the system are also using significant memory.
  • Setting the fraction too low might restrict the MPS device's capabilities and hinder performance.
  • For background: MPS provides an interface to Apple's Metal API for accelerated computation on the device's GPU, and it can significantly improve the performance of PyTorch models on Apple devices with powerful GPUs. A quick availability check is sketched below.
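
Before setting a memory fraction, it can help to confirm that the MPS backend is both compiled into your PyTorch build and usable on the current machine. A minimal check using the standard torch.backends.mps helpers might look like this:

import torch
import torch.mps

# Verify that PyTorch was built with MPS support and that the device is usable
if torch.backends.mps.is_built() and torch.backends.mps.is_available():
  torch.mps.set_per_process_memory_fraction(0.75)
else:
  print("MPS backend is not available on this machine.")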


Setting a Memory Limit

import torch.mps

# Limit this process to 50% of the recommended MPS device memory
memory_fraction = 0.5
torch.mps.set_per_process_memory_fraction(memory_fraction)

# Your PyTorch code using MPS device
# ...

Checking Available Memory

import torch.mps

# Set a memory limit (optional)
# memory_fraction = 0.75
# torch.mps.set_per_process_memory_fraction(memory_fraction)

# Get current allocated memory on the MPS device
allocated_memory = torch.mps.current_allocated_memory()
print(f"Currently allocated memory on MPS device: {allocated_memory} bytes")

# Get total GPU memory allocated by Metal driver
total_memory = torch.mps.driver_allocated_memory()
print(f"Total GPU memory allocated by Metal: {total_memory} bytes")

# Your PyTorch code using MPS device
# ...

Handling Out-of-Memory Errors (Optional)

import torch.mps

try:
  # Raise the memory limit above the recommended maximum; values between 1 and 2
  # are accepted but can destabilize the system under memory pressure
  memory_fraction = 1.2
  torch.mps.set_per_process_memory_fraction(memory_fraction)

  # Your PyTorch code using MPS device (allocations here may raise the error)
  # ...
except RuntimeError as e:
  if "out of memory" in str(e):
    print("Out-of-memory error occurred! Reducing memory usage...")
    # Release cached MPS memory and lower the limit before retrying
    torch.mps.empty_cache()
    memory_fraction = 0.75
    torch.mps.set_per_process_memory_fraction(memory_fraction)
    # ... (Retry or adjust your code)
  else:
    raise  # Re-raise other errors

# Your PyTorch code using MPS device (with reduced memory usage)
# ...
  • The third example demonstrates handling potential out-of-memory errors, but it is a simplified approach. In real-world scenarios you may want a more robust error-handling mechanism tailored to your use case; one possible pattern is sketched below.
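
For example, a retry loop that empties the MPS cache and halves the batch size after each out-of-memory failure is one common pattern. This is only a sketch: run_step() and the starting batch size are hypothetical placeholders for your own workload, and it assumes an MPS-capable machine; torch.mps.empty_cache() and the RuntimeError check are the only parts tied to the actual PyTorch API.

import torch
import torch.mps

def run_step(batch_size):
  # Hypothetical workload; replace with your own forward/backward pass
  x = torch.randn(batch_size, 1024, device="mps")
  return (x @ x.T).sum()

batch_size = 4096
while batch_size >= 1:
  try:
    run_step(batch_size)
    print(f"Step succeeded with batch_size={batch_size}")
    break
  except RuntimeError as e:
    if "out of memory" not in str(e):
      raise
    # Free cached MPS blocks and retry with a smaller batch
    torch.mps.empty_cache()
    batch_size //= 2
    print(f"OOM caught; retrying with batch_size={batch_size}")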


Monitor Memory Usage

  • Leverage functions such as torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory() to track memory usage during training or inference (note that torch.cuda.memory_summary() reports only on CUDA devices and does not cover MPS); a small logging helper is sketched below.
  • Based on the observed usage, you can dynamically adjust model parameters, batch sizes, or other aspects of your code to reduce memory consumption.
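
A small logging helper along these lines might look like the following; log_mps_memory is just an illustrative name, and the recommended_max_memory() call assumes a recent PyTorch release (2.3+):

import torch.mps

def log_mps_memory(tag=""):
  # Memory currently occupied by PyTorch tensors on the MPS device
  allocated = torch.mps.current_allocated_memory()
  # Total memory the Metal driver has allocated for this process
  driver = torch.mps.driver_allocated_memory()
  # Recommended maximum working set size (assumes PyTorch 2.3+)
  recommended = torch.mps.recommended_max_memory()
  print(f"[{tag}] allocated={allocated / 1e6:.1f} MB, "
        f"driver={driver / 1e6:.1f} MB, "
        f"recommended max={recommended / 1e6:.1f} MB")

# Example: call log_mps_memory("after forward pass") inside your training loop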

Reduce Model Complexity

  • Consider techniques like model pruning, quantization, or knowledge distillation to make your model smaller and more memory-efficient. These approaches can often achieve comparable performance with a smaller memory footprint; a pruning sketch follows this list.
  • Reducing model size inherently lowers the memory required for computations.
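
As one concrete example, torch.nn.utils.prune can zero out a chosen share of a layer's weights. The single linear layer and the 30% amount below are arbitrary illustrative choices; note that pruning by itself only zeroes weights, so the memory savings come when the resulting sparsity is later exploited or the model is compressed:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical layer used purely for illustration
layer = nn.Linear(1024, 1024)

# Zero out the 30% of weights with the smallest L1 magnitude
prune.l1_unstructured(layer, name="weight", amount=0.3)

# Make the pruning permanent (removes the mask and the saved original weights)
prune.remove(layer, "weight")

sparsity = (layer.weight == 0).float().mean().item()
print(f"Fraction of zeroed weights: {sparsity:.2f}")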

Manage Batch Size

  • Experiment with different batch sizes: a smaller batch size uses less memory at the cost of potentially slower training. A simple measurement loop is sketched below.
  • Choose the batch size that best balances memory constraints with training efficiency.
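
One simple way to compare candidates is to run a forward pass at each batch size and record how much MPS memory it uses. The model, layer sizes, and batch sizes below are illustrative placeholders, and the sketch assumes an MPS-capable machine:

import torch
import torch.nn as nn
import torch.mps

device = torch.device("mps")
model = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 10)).to(device)

for batch_size in (16, 64, 256, 1024):
  torch.mps.empty_cache()
  x = torch.randn(batch_size, 2048, device=device)
  _ = model(x)
  torch.mps.synchronize()  # ensure the queued work has actually executed
  used = torch.mps.current_allocated_memory()
  print(f"batch_size={batch_size}: {used / 1e6:.1f} MB allocated")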

Utilize Automatic Mixed Precision (AMP)

  • If your hardware and PyTorch version support it, enable AMP (Automatic Mixed Precision). AMP mixes data types (e.g., float16 and float32) during training, reducing memory consumption compared to using float32 everywhere; a sketch follows.
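
A minimal sketch, assuming a PyTorch release in which torch.autocast accepts the "mps" device type (only recent versions do) and an MPS-capable machine; the model and input are placeholders:

import torch
import torch.nn as nn

device = torch.device("mps")
model = nn.Linear(512, 10).to(device)
x = torch.randn(32, 512, device=device)

# Eligible ops run in float16 inside the autocast region, while
# precision-sensitive ops stay in float32
with torch.autocast(device_type="mps", dtype=torch.float16):
  out = model(x)

print(out.dtype)  # typically torch.float16, since the matmul ran under autocast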

Leverage Cloud Resources

  • If your local machine's GPU memory is insufficient, consider using cloud platforms like Google Colab, Amazon SageMaker, or Microsoft Azure that offer instances with powerful GPUs and larger memory capacities.

Fall Back to the CPU

  • While not ideal for performance-critical tasks due to slower computation compared to GPUs, training on the CPU can be a viable option for small models or when memory limitations are severe. A minimal fallback pattern is sketched below.
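
A minimal device-selection pattern along these lines places the model and data on the CPU whenever the MPS backend is unavailable (or whenever GPU memory pressure becomes a problem); the small model and batch are placeholders:

import torch
import torch.nn as nn

# Prefer MPS when it is available; otherwise run on the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(256, 10).to(device)        # hypothetical model
batch = torch.randn(8, 256, device=device)   # hypothetical batch
output = model(batch)
print(f"Forward pass ran on: {output.device}")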