Beyond torch.median: Alternative Approaches for Median Calculation in PyTorch


Purpose

  • Calculates the median value(s) of a PyTorch tensor, either across all elements or along a specified dimension.
  • The median is the "middle" value when the data is sorted in ascending order.

Behavior

  • Operates on a single tensor input.
  • When dim is omitted, returns a single tensor holding the median of all elements.
  • When dim is given, returns a namedtuple containing two elements:
    • values: A tensor containing the median value(s) along dim.
    • indices: A tensor containing the index of each median element along dim.
  • Accepts an optional keepdim argument (default False) that retains the reduced dimension with size 1.
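
The two call forms can be compared side by side. This is a minimal sketch; the tensor x and the variable names are illustrative and not taken from the original text.

```python
import torch

x = torch.tensor([[1.0, 7.0, 3.0],
                  [4.0, 5.0, 6.0]])

# Without dim: a single 0-dimensional tensor with the median of all six elements.
overall = torch.median(x)

# With dim: a namedtuple with the fields `values` and `indices`.
result = torch.median(x, dim=1)

print(overall)         # median over the whole tensor
print(result.values)   # one median per row
print(result.indices)  # position of each median within its row
```

Accessing the fields by name (result.values, result.indices) is equivalent to tuple unpacking and can make it clearer which of the two outputs is being used.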

Example

import torch

# Create a sample tensor
data = torch.tensor([[1, 7, 3], [4, 5, 6], [2, 8, 9]])

# Calculate the median along dim 0 (down the rows, giving one result per column)
medians, indices = torch.median(data, dim=0)

print(medians)  # tensor([2, 7, 6])
print(indices)  # tensor([2, 0, 1])

  • medians contains the median of each column: [2, 7, 6].
  • indices contains the row index of the median element in each column: [2, 0, 1]. For example, in the first column [1, 4, 2], the median is 2, which is at index 2.
  • torch.median is itself a vectorized operation. A manual sort-and-index approach is not guaranteed to be faster; relative speed depends on tensor size, device, and PyTorch version, so benchmark before switching in performance-critical code.
  • If the input contains NaN (Not a Number) values, torch.median returns NaN for the affected slices. Use torch.nanmedian instead, which ignores NaNs.
  • For tensors with an even number of elements along the specified dimension, torch.median returns the lower of the two middle values.
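
The even-length and NaN behaviors described above can be checked directly. The following is a small sketch with made-up input values.

```python
import torch

# Even number of elements: torch.median picks the lower of the two middle values.
even = torch.tensor([1.0, 2.0, 3.0, 4.0])
lower_middle = torch.median(even)     # 2.0, not the average 2.5

# NaN handling: torch.median propagates NaN, torch.nanmedian ignores it.
with_nan = torch.tensor([1.0, float("nan"), 3.0, 2.0])
propagated = torch.median(with_nan)   # nan
ignored = torch.nanmedian(with_nan)   # median of the remaining [1, 2, 3]

print(lower_middle, propagated, ignored)
```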


Calculating Median of the Entire Tensor

import torch

# Create a sample tensor
data = torch.tensor([3, 1, 4, 1, 5])

# Calculate median without specifying a dimension (across all elements)
median = torch.median(data)

print(median)  # tensor(3)

In this case, dim is omitted, so the median is calculated across all elements. The result is a single 0-dimensional tensor (3) rather than a (values, indices) pair.
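
The same holds for tensors with more than one dimension: omitting dim behaves like taking the median of the flattened data. A brief sketch, using an illustrative 2D tensor named grid:

```python
import torch

grid = torch.tensor([[3, 1, 4],
                     [1, 5, 9]])

overall = torch.median(grid)               # 0-dimensional tensor
flattened = torch.median(grid.flatten())   # same result via an explicit flatten
as_number = overall.item()                 # plain Python number

print(overall, flattened, as_number)
```

Because the tensor has six elements, the result is the lower of the two middle values of the sorted data [1, 1, 3, 4, 5, 9].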

Calculating Median Along One Dimension of a Multi-Dimensional Tensor

import torch

# Create a 3D sample tensor
data = torch.arange(24).reshape(2, 3, 4)

# Calculate median along the second dimension (dim=1)
medians, indices = torch.median(data, dim=1)

print(medians.shape)  # torch.Size([2, 4])
print(indices.shape)   # torch.Size([2, 4])

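# dim 1 (size 3) is reduced away; dims 0 and 2 remain

This example calculates the median along dimension 1 of a tensor with shape [2, 3, 4]. That dimension is reduced away, so medians and indices both have shape [2, 4]: one median, and its index along dimension 1, for every combination of positions in dimensions 0 and 2. Note that torch.median accepts a single dim only, so it cannot reduce over several dimensions in one call.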
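
To see how indices relates to the input, the returned indices can be used to look the medians back up in the original tensor. This sketch uses torch.gather; the variable name picked is illustrative.

```python
import torch

data = torch.arange(24).reshape(2, 3, 4)
medians, indices = torch.median(data, dim=1)

# Re-insert the reduced dimension so the indices can address `data` along dim 1,
# then drop it again after the lookup.
picked = torch.gather(data, 1, indices.unsqueeze(1)).squeeze(1)

print(torch.equal(picked, medians))  # the gathered elements are the medians
```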

Keeping the Reduced Dimension with keepdim

import torch

data = torch.tensor([[1, 7, 3], [4, 5, 6]])

# Calculate median along the first dimension, keeping the dimension
medians, indices = torch.median(data, dim=0, keepdim=True)

print(medians.shape)  # torch.Size([1, 3])
print(indices.shape)   # torch.Size([1, 3])
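
One common reason to keep the reduced dimension is broadcasting. As a sketch (the centering use case is an illustration, not something the original example does), the per-row medians can be subtracted from each row directly:

```python
import torch

data = torch.tensor([[1.0, 7.0, 3.0],
                     [4.0, 5.0, 6.0],
                     [2.0, 8.0, 9.0]])

# keepdim=True leaves a size-1 dimension, so the result has shape (3, 1)
# and broadcasts across the columns of each row.
row_medians = torch.median(data, dim=1, keepdim=True).values

centered = data - row_medians
print(centered)
```

Without keepdim=True the medians would have shape (3,), which broadcasts along the last dimension and would subtract the wrong value from each element.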


torch.quantile for Median with More Control

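  • Use torch.quantile with q=0.5 to calculate the median.
  • You can specify any quantile q (between 0 and 1) to calculate other percentiles as well.
  • Its interpolation argument controls how even-numbered inputs are handled: the default, "linear", averages the two middle values for q=0.5, while "lower" matches torch.median.
  • The input must be a floating-point tensor (float or double); integer tensors raise an error.

import torch

# torch.quantile requires a floating-point input
data = torch.tensor([1.0, 7.0, 3.0])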

# Calculate median using quantile
median = torch.quantile(data, q=0.5)

print(median)  # tensor(3.)
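
The difference from torch.median shows up with an even number of elements. The sketch below assumes a PyTorch version in which torch.quantile supports the interpolation keyword; the input values are made up.

```python
import torch

even = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Default ("linear") interpolates between the two middle values 2 and 3.
linear = torch.quantile(even, q=0.5)
# "lower" and "higher" select one of the two middle values instead.
lower = torch.quantile(even, q=0.5, interpolation="lower")
higher = torch.quantile(even, q=0.5, interpolation="higher")

# Several quantiles can be requested at once by passing a tensor for q.
quartiles = torch.quantile(even, torch.tensor([0.25, 0.5, 0.75]))

print(linear, lower, higher, quartiles)
```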

Custom Median Function (For Specific Use Cases)

  • If you need more customization or control over the median calculation, consider writing a custom function.
  • This can be useful for handling specific edge cases (for example, averaging the two middle values of an even-length input) or for incorporating additional logic.
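
As one possible illustration, the helper below averages the two middle values when the count is even. The name midpoint_median and its signature are hypothetical; it is not part of PyTorch.

```python
import torch

def midpoint_median(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Median that averages the two middle values when the count is even.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    sorted_x, _ = torch.sort(x, dim=dim)
    n = sorted_x.shape[dim]
    lower = sorted_x.select(dim, (n - 1) // 2)
    upper = sorted_x.select(dim, n // 2)
    # For odd n both indices coincide, so the average is the middle value itself.
    return (lower + upper) / 2

print(midpoint_median(torch.tensor([1.0, 2.0, 3.0, 4.0])))
print(midpoint_median(torch.tensor([[1.0, 7.0, 3.0], [4.0, 5.0, 6.0]]), dim=1))
```

Unlike torch.median, this sketch does not return indices, since an averaged value generally does not correspond to a single element of the input.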

Sorting and Slicing (Performance-Critical Scenarios)

  • For very large tensors where performance is crucial, you can compute the median manually by sorting and indexing.
  • This can be more efficient than torch.median in some cases, but that is not guaranteed, so benchmark on your own data and hardware. It also requires more manual steps.

import torch

data = torch.randn(10000)

# Calculate median using sorting and indexing (manual approach)
sorted_data, _ = torch.sort(data, dim=0)
# Use the lower middle index so even-length inputs match torch.median
median_index = (sorted_data.shape[0] - 1) // 2
median = sorted_data[median_index]

# Speed relative to torch.median varies; benchmark before relying on this
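
Whether the manual approach pays off is best measured rather than assumed. The following is a rough timing sketch for CPU tensors; the helpers sort_median and time_it are hypothetical names introduced here, and the numbers it prints will differ between machines.

```python
import time
import torch

torch.manual_seed(0)
data = torch.randn(100_000)

def sort_median(x: torch.Tensor) -> torch.Tensor:
    # Lower middle element, matching torch.median's convention.
    return torch.sort(x).values[(x.numel() - 1) // 2]

def time_it(fn, x, repeats=5):
    fn(x)  # warm-up call, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    return (time.perf_counter() - start) / repeats

same = torch.equal(sort_median(data), torch.median(data))
print("results match:", same)
print("torch.median :", time_it(torch.median, data))
print("sort + index :", time_it(sort_median, data))
```

On a GPU, the timing loop would additionally need torch.cuda.synchronize() before reading the clock, because CUDA kernels run asynchronously.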

To summarize when to use each approach:

  • If you need basic median calculation and performance isn't a major concern, torch.median is a good choice.
  • Use torch.quantile if you need more control over handling even-numbered elements or want to calculate other percentiles.
  • Consider a custom function if you have specific requirements or logic to incorporate.
  • For very large tensors and performance-critical scenarios, benchmark a sort-and-index approach against torch.median, and be aware of the additional manual steps involved.