Understanding PyTorch Upsampling: Moving Beyond torch.nn.functional.upsample


Functionality

  • Increases the resolution (spatial dimensions) of a tensor; the batch and channel dimensions are left unchanged.
  • Commonly used for image or feature map upscaling, often within generative models (e.g., GANs), or for processing data at multiple scales.

Arguments

  • mode (str, optional): The algorithm used for upsampling. Supported modes include:
    • 'nearest' (default): Replicates the value of the nearest neighboring element.
    • 'linear' (1D only, i.e. a 3D tensor of shape N x C x L): Performs linear interpolation.
    • 'bilinear' (2D only, i.e. a 4D tensor of shape N x C x H x W): Performs bilinear interpolation, commonly used for image upsampling.
    • 'bicubic' (2D only, i.e. a 4D tensor of shape N x C x H x W): Performs bicubic interpolation, which gives smoother results than bilinear at a higher computational cost.
    • 'trilinear' (3D only, i.e. a 5D tensor of shape N x C x D x H x W): Performs trilinear interpolation, the 3D equivalent of bilinear interpolation.
  • scale_factor (float or tuple, optional): Multiplier for the spatial size. Provide either scale_factor or size, not both; passing both raises a ValueError. Values greater than 1 upsample, while values less than 1 downsample. A tuple must contain one entry per spatial dimension.
  • size (int or tuple, optional): Desired output spatial size. A single integer is used for every spatial dimension; a tuple sets each spatial dimension independently.
  • input (Tensor): The input tensor to be upsampled.
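
The interaction between size and scale_factor can be sketched as follows. This is a minimal illustration, assuming a recent PyTorch release; the tensor shape and the target sizes are arbitrary choices, not values from any particular model.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10)  # (batch, channels, height, width)

# scale_factor multiplies each spatial dimension.
up = F.interpolate(x, scale_factor=2, mode='nearest')
print(up.shape)     # torch.Size([1, 3, 20, 20])

# A scale_factor below 1 downsamples instead.
down = F.interpolate(x, scale_factor=0.5, mode='nearest')
print(down.shape)   # torch.Size([1, 3, 5, 5])

# size sets the output height and width directly.
sized = F.interpolate(x, size=(15, 25), mode='nearest')
print(sized.shape)  # torch.Size([1, 3, 15, 25])

# Passing both size and scale_factor is rejected with a ValueError.
try:
    F.interpolate(x, size=(20, 20), scale_factor=2, mode='nearest')
    both_rejected = False
except ValueError:
    both_rejected = True
print(both_rejected)  # True
```

Only the spatial dimensions change; the batch size and channel count are the same in every output above.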

Important Note

  • torch.nn.functional.upsample is deprecated in favor of the more general torch.nn.functional.interpolate function. The deprecated function accepts the same arguments but emits a warning when called, so new code should call interpolate directly.
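
As a rough check that the migration is a drop-in change, the sketch below calls both functions with identical arguments and compares the results. It assumes the installed PyTorch version still ships the deprecated upsample function; the warning it emits is silenced only for this comparison.

```python
import warnings

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10)

# Deprecated call; suppress its deprecation warning for this comparison only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    old = F.upsample(x, scale_factor=2, mode='bilinear', align_corners=False)

# Replacement call with the same arguments.
new = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

print(old.shape)              # torch.Size([1, 3, 20, 20])
print(torch.equal(old, new))  # True
```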

Example (Bilinear Upsampling)

import torch
from torch import nn

input = torch.randn(1, 3, 10, 10)  # Batch size 1, 3 channels, 10x10 image
scale_factor = 2  # Upsample by a factor of 2

upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bilinear')
print(upsampled.shape)  # Output: torch.Size([1, 3, 20, 20])

Notes

  • The choice of mode affects the quality and smoothness of the upsampled output. Bilinear interpolation is a good balance between quality and computational cost for most image upsampling tasks.
  • Upsampling fills the enlarged spatial dimensions with values computed from the existing ones; it does not recover detail that was absent from the input.

Nearest Neighbor Upsampling

import torch
from torch import nn

input = torch.randn(1, 3, 10, 10)  # Batch size 1, 3 channels, 10x10 image
scale_factor = 2

upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='nearest')
print(upsampled.shape)  # Output: torch.Size([1, 3, 20, 20])

This code performs nearest neighbor upsampling, which is computationally efficient but can result in blocky artifacts.
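
The blocky pattern is easy to see on a tiny tensor. The sketch below uses a hand-written 2x2 input (an illustrative choice) so that each source value can be traced to the 2x2 block it fills in the output.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[0., 1.],
                    [2., 3.]]]])  # shape (1, 1, 2, 2)

up = F.interpolate(x, scale_factor=2, mode='nearest')
print(up[0, 0])
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.],
#         [2., 2., 3., 3.],
#         [2., 2., 3., 3.]])
```

No new values are created: every output element is a copy of an input element, which is why edges stay sharp but look pixelated.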

Bicubic Upsampling (2D)

import torch
from torch import nn

input = torch.randn(1, 3, 10, 10)  # Batch size 1, 3 channels, 10x10 image
scale_factor = 3

upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bicubic')
print(upsampled.shape)  # Output: torch.Size([1, 3, 30, 30])

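This code uses bicubic interpolation for smoother upsampling, but it's more computationally expensive than bilinear. Bicubic interpolation can also overshoot, producing values slightly outside the range of the input, so clamp the result if downstream code expects a fixed range.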
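
One practical difference from bilinear interpolation is that bicubic output can fall outside the input's value range near sharp edges. The following sketch, using a synthetic step-edge image chosen purely for illustration, compares the value ranges of the two modes.

```python
import torch
import torch.nn.functional as F

# A hard vertical edge: left half 0, right half 1.
x = torch.zeros(1, 1, 8, 8)
x[..., 4:] = 1.0

bilinear = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
bicubic = F.interpolate(x, scale_factor=2, mode='bicubic', align_corners=False)

# Bilinear stays within [0, 1]; bicubic dips below 0 and rises above 1 near the edge.
print(bilinear.min().item(), bilinear.max().item())
print(bicubic.min().item(), bicubic.max().item())

# Clamp if downstream code expects values in [0, 1].
bicubic_clamped = bicubic.clamp(0.0, 1.0)
```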

Trilinear Upsampling (3D)

import torch
from torch import nn

input = torch.randn(1, 3, 5, 10, 10)  # Batch size 1, 3 channels, 5x10x10 volume
scale_factor = 2

upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='trilinear')
print(upsampled.shape)  # Output: torch.Size([1, 3, 10, 20, 20])

This code performs trilinear interpolation for upsampling 3D volumes (e.g., video data), which are passed as 5D tensors of shape (batch, channels, depth, height, width).
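
For video-like data it is often useful to enlarge only the frame height and width while keeping the number of frames fixed. A minimal sketch, assuming the depth dimension holds frames, passes a per-dimension scale_factor tuple:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 5, 10, 10)  # (batch, channels, frames, height, width)

# One scale per spatial dimension: keep 5 frames, double height and width.
up = F.interpolate(x, scale_factor=(1.0, 2.0, 2.0), mode='trilinear',
                   align_corners=False)
print(up.shape)  # torch.Size([1, 3, 5, 20, 20])
```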

Upsampling to a Specific Size

import torch
from torch import nn

input = torch.randn(1, 3, 10, 10)  # Batch size 1, 3 channels, 10x10 image
output_size = (20, 20)

upsampled = nn.functional.interpolate(input, size=output_size, mode='bilinear')
print(upsampled.shape)  # Output: torch.Size([1, 3, 20, 20])

This code upsamples the input to a specified output size (output_size), which is useful when the target size is not an integer multiple of the input size.
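
A common use of size is matching one feature map to another before combining them, as in encoder-decoder networks with skip connections. The sketch below uses two hypothetical feature maps, low_res and skip, whose shapes are made up for illustration.

```python
import torch
import torch.nn.functional as F

low_res = torch.randn(1, 8, 13, 13)  # hypothetical decoder feature map
skip = torch.randn(1, 8, 27, 27)     # hypothetical encoder feature map

# 27 is not an integer multiple of 13, so take the target size from skip.
resized = F.interpolate(low_res, size=skip.shape[-2:], mode='bilinear',
                        align_corners=False)
print(resized.shape)  # torch.Size([1, 8, 27, 27])

# The two maps can now be concatenated along the channel dimension.
merged = torch.cat([resized, skip], dim=1)
print(merged.shape)   # torch.Size([1, 16, 27, 27])
```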



torch.nn.functional.interpolate

This function offers a more general and powerful approach to upsampling (and downsampling) tensors. It provides the following advantages:

  • Flexibility
    Supports various modes ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear', among others) and allows specifying either the target output size or a scale factor.
  • Additional Features
    Handles 1D, 2D, and 3D data, and provides an alignment option (align_corners) for the 'linear', 'bilinear', 'bicubic', and 'trilinear' modes.

Example (Replacing upsample with interpolate)

import torch
from torch import nn

input = torch.randn(1, 3, 10, 10)  # Batch size 1, 3 channels, 10x10 image
scale_factor = 2

# Equivalent to deprecated upsample with bilinear mode
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bilinear')
print(upsampled.shape)  # Output: torch.Size([1, 3, 20, 20])
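
The align_corners option changes where output samples are placed relative to the input samples. A minimal 1D sketch, using a two-element input chosen so the values are easy to verify by hand:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[0., 1.]]])  # shape (1, 1, 2): one channel, two samples

# align_corners=True: the first and last outputs coincide with the input
# endpoints, and the rest are evenly spaced between them.
aligned = F.interpolate(x, size=4, mode='linear', align_corners=True)

# align_corners=False: samples are treated as cell centers, and positions
# outside the input range are clamped to the edge values.
unaligned = F.interpolate(x, size=4, mode='linear', align_corners=False)

print(aligned[0, 0])    # approximately [0.0000, 0.3333, 0.6667, 1.0000]
print(unaligned[0, 0])  # [0.0000, 0.2500, 0.7500, 1.0000]
```

Setting align_corners explicitly makes the intended behavior clear to readers and keeps results consistent when porting code between frameworks that use different conventions.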