Extracting Non-Zero Element Indices: torch.argwhere vs. Alternatives


torch.argwhere Function

In PyTorch, torch.argwhere (also available as the tensor method Tensor.argwhere) returns a new tensor containing the indices of all non-zero elements of the input tensor. It is similar to NumPy's argwhere function.

Key Points

  • Non-Zero Elements
    Any element that is not exactly zero counts as non-zero. This covers positive and negative numbers alike, including very small floating-point values (such as rounding residue) that are close to zero but not exactly equal to it. In a boolean tensor, True counts as non-zero.
  • Output
    Returns a new tensor of dtype torch.long (64-bit integer) with shape (N, M), where N is the number of non-zero elements in the input tensor and M is its number of dimensions (rank). Each row of the output holds the indices of one non-zero element, and rows are sorted in row-major (C-style) order, with the last index changing fastest.
  • Input
    Takes a single tensor as input.
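The points above can be checked directly. This minimal sketch uses the function form torch.argwhere(t), which is equivalent to calling t.argwhere(); the sample values are arbitrary and chosen only to include a tiny non-zero float.

```python
import torch

# 1e-30 is representable in float32, so it is not exactly zero and counts as non-zero
t = torch.tensor([[0.0, 1e-30], [-2.5, 0.0]])

idx = torch.argwhere(t)  # same result as t.argwhere()
print(idx)        # one row of (row, column) indices per non-zero element
print(idx.dtype)  # torch.int64
print(idx.shape)  # torch.Size([2, 2]): 2 non-zero elements, 2 dimensions
```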

Example

import torch

# Create a sample tensor
tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])

# Get indices of non-zero elements using torch.argwhere
non_zero_indices = tensor.argwhere()

print(non_zero_indices)

This code will output:

tensor([[0, 0],
        [0, 2],
        [1, 1]])

As you can see, the output tensor shows the indices (row, column) of each non-zero element in the original tensor.

Comparison with torch.nonzero

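  • torch.argwhere(input) returns the same result as torch.nonzero(input) with its default as_tuple=False: a 2-D tensor of shape (N, M) with one row of indices per non-zero element.
  • torch.nonzero(input, as_tuple=True) instead returns a tuple of M 1-D tensors, one per dimension. torch.where(condition), called with a single argument, is identical to torch.nonzero(condition, as_tuple=True).
  • On a CUDA tensor, both torch.argwhere and torch.nonzero cause host-device synchronization: the size of the output depends on the data, so the CPU has to wait for the GPU to count the non-zero elements before the result can be allocated. Using torch.nonzero instead of torch.argwhere, with or without the out argument, does not avoid this. If the synchronization is a bottleneck, work with a boolean mask directly where possible instead of materializing indices.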
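A small sketch of how these functions relate, run on the CPU so it works anywhere: torch.argwhere(t) and torch.nonzero(t) return the same tensor, and single-argument torch.where matches torch.nonzero(..., as_tuple=True). torch.stack is used here only to put the per-dimension tuple back into the (N, M) layout for comparison.

```python
import torch

t = torch.tensor([[1, 0, 3], [0, 4, 0]])

# argwhere and the default nonzero produce the same (N, M) index tensor
a = torch.argwhere(t)
n = torch.nonzero(t)
print(torch.equal(a, n))  # True

# Single-argument torch.where matches nonzero(..., as_tuple=True)
where_rows, where_cols = torch.where(t != 0)
tuple_rows, tuple_cols = torch.nonzero(t != 0, as_tuple=True)
print(torch.equal(where_rows, tuple_rows), torch.equal(where_cols, tuple_cols))  # True True

# Stacking the per-dimension index tensors as columns recovers the (N, M) layout
print(torch.equal(torch.stack((where_rows, where_cols), dim=1), a))  # True
```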


Finding Indices with a Threshold

import torch

# Sample tensor
tensor = torch.tensor([[-1, 2, 0.5], [0, -3, 4]])

# Magnitude threshold: elements with |value| <= threshold are treated as zero (adjust as needed)
threshold = 0.1

# Create a mask marking elements whose absolute value exceeds the threshold
mask = torch.abs(tensor) > threshold

# Use torch.nonzero to get the indices where the mask is True
above_threshold_indices = torch.nonzero(mask)

# Print the indices
print(above_threshold_indices)

This code first defines a threshold and then builds a boolean mask with torch.abs(tensor) > threshold, which is True wherever an element's absolute value exceeds the threshold. torch.nonzero(mask) then returns the (row, column) indices of those elements in the original tensor.
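To get the values that passed the threshold alongside their indices, the same mask can index the tensor directly. Boolean-mask indexing returns elements in the same row-major order that torch.nonzero uses, so the two results line up. A minimal sketch reusing the sample data:

```python
import torch

tensor = torch.tensor([[-1, 2, 0.5], [0, -3, 4]])
threshold = 0.1
mask = torch.abs(tensor) > threshold

indices = torch.nonzero(mask)  # (N, 2) positions of elements above the threshold
values = tensor[mask]          # the N matching values, in the same row-major order

# Pair each position with its value
for (row, col), value in zip(indices.tolist(), values.tolist()):
    print(f"({row}, {col}) -> {value}")
```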

Multi-Dimensional Tensors

import torch

# 3D tensor
tensor = torch.tensor([[[1, 0], [2, 3]], [[4, 0], [5, 6]]])

# Get non-zero indices
non_zero_indices = tensor.argwhere()

# Print the indices
print(non_zero_indices)

This code creates a 3D tensor and uses tensor.argwhere() to obtain a tensor with one row per non-zero element. Each row has three entries, giving the element's position along each of the three dimensions.
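For the 3D tensor above, the result has shape (6, 3). Transposing it and converting it to a tuple gives one index tensor per dimension, which can be used with advanced indexing to gather the non-zero values. A short sketch:

```python
import torch

tensor = torch.tensor([[[1, 0], [2, 3]], [[4, 0], [5, 6]]])
idx = tensor.argwhere()

print(idx.shape)  # torch.Size([6, 3]): 6 non-zero elements, 3 dimensions

# tuple(idx.T) splits the columns into one index tensor per dimension
values = tensor[tuple(idx.T)]
print(values)  # tensor([1, 2, 3, 4, 5, 6])
```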

  • Performance note: on CUDA tensors, torch.argwhere and torch.nonzero both force a host-device synchronization, so switching from one to the other does not help. Reducing how often indices are materialized does.
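When the goal is to count, zero out, or otherwise transform the selected elements rather than list their positions, the boolean mask alone is often enough. Operations like the ones below have an output shape that does not depend on the data, so on the GPU they do not need to wait for the host to learn how many elements matched. A minimal sketch, run on the CPU here; the "greater than 2" condition is only an illustration:

```python
import torch

tensor = torch.tensor([[[1, 0], [2, 3]], [[4, 0], [5, 6]]])
mask = tensor != 0

# Number of non-zero elements, kept as a 0-d tensor on the tensor's device
count = mask.sum()

# Keep elements greater than 2 and replace the rest with 0; the output has the input's shape
filtered = torch.where(tensor > 2, tensor, torch.zeros_like(tensor))
print(count, filtered.shape)
```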


torch.nonzero

  • Functionality
    Retrieves the indices of non-zero elements. With the default as_tuple=False it returns the same (N, M) tensor as torch.argwhere; with as_tuple=True it returns a tuple of M 1-D tensors, one holding the indices along each dimension.
  • Advantage
    The as_tuple=True form can be used directly for advanced indexing, and the function also accepts an out argument. It is not faster than torch.argwhere on the GPU: both cause host-device synchronization.
  • Example
import torch

tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])
non_zero_indices = torch.nonzero(tensor)
print(non_zero_indices)
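The as_tuple=True form returns one index tensor per dimension, which plugs straight into advanced indexing for reading values or for updating them. A short sketch on the same sample tensor; the scaling by 10 is arbitrary:

```python
import torch

tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])

rows, cols = torch.nonzero(tensor, as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([0, 2, 1])

# Read the non-zero values through the index tuple
print(tensor[rows, cols])  # tensor([1, 3, 4])

# Scale only the non-zero entries, on a copy so the original is unchanged
scaled = tensor.clone()
scaled[rows, cols] *= 10
print(scaled)
```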

Masking and Indexing

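  • Functionality
    Builds a boolean mask that marks the elements of interest (here, tensor != 0) and passes it to torch.where, which returns a tuple of index tensors, one per dimension.
  • Advantage
    The mask can encode conditions more complex than "non-zero", and the same mask can also index the tensor to extract the matching values.
  • Example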
import torch

tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])
mask = tensor != 0  # Boolean mask of non-zero elements
rows, cols = torch.where(mask)  # One index tensor per dimension
print(rows, cols)
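Because the mask is just a boolean tensor, conditions can be combined with & (and), | (or), and ~ (not) before extracting indices. A sketch selecting positive, odd elements; the condition itself is only an illustration:

```python
import torch

tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])

# Parentheses are required: in Python, & binds more tightly than > and ==
mask = (tensor > 0) & (tensor % 2 == 1)

rows, cols = torch.where(mask)
print(rows, cols)    # tensor([0, 0]) tensor([0, 2])
print(tensor[mask])  # tensor([1, 3])
```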

Custom Loop (Less Common)

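  • Functionality
    Iterates over the tensor element by element and stores the indices of non-zero elements in a Python list. The version below handles 2-D tensors only.
  • Advantage
    Provides fine-grained control over the process, but every element is visited in the Python interpreter, so it is much slower than the vectorized alternatives on large tensors.
  • Example (not recommended for large tensors)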
import torch

tensor = torch.tensor([[1, 0, 3], [0, 4, 0]])
non_zero_indices = []
for i in range(tensor.size(0)):  # rows
    for j in range(tensor.size(1)):  # columns
        if tensor[i, j] != 0:
            non_zero_indices.append([i, j])

print(non_zero_indices)  # [[0, 0], [0, 2], [1, 1]]
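The loop above assumes a 2-D tensor. If a loop is really needed, a version that works for any number of dimensions can walk every index with itertools.product, which visits positions in row-major order, the same order torch.argwhere uses. The helper name loop_argwhere is hypothetical, and this is a sketch for illustration rather than something to use on large tensors:

```python
import itertools

import torch


def loop_argwhere(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-Python equivalent of torch.argwhere (slow; illustration only)."""
    indices = []
    # product(range(d0), range(d1), ...) yields every index tuple, last dimension fastest
    for pos in itertools.product(*(range(size) for size in t.shape)):
        if t[pos].item() != 0:
            indices.append(list(pos))
    if not indices:
        # Match argwhere's shape for an all-zero input: (0, number of dimensions)
        return torch.empty((0, t.dim()), dtype=torch.long)
    return torch.tensor(indices, dtype=torch.long)


tensor = torch.tensor([[[1, 0], [2, 3]], [[4, 0], [5, 6]]])
print(torch.equal(loop_argwhere(tensor), torch.argwhere(tensor)))  # True
```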

The best alternative depends on your specific use case:

  • Custom loops should be reserved for specific scenarios where you need very fine-grained control over the process, but keep in mind the potential performance implications.
  • If you require more complex filtering criteria, masking and indexing offer greater flexibility.
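  • If you simply need the indices of non-zero elements, torch.argwhere and torch.nonzero give the same result; use torch.nonzero(as_tuple=True) when you want per-dimension index tensors for indexing. On the GPU, both synchronize with the host, so avoid calling them repeatedly in performance-critical code.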