Understanding PyTorch Upsampling: Moving Beyond torch.nn.functional.upsample
Functionality
- Increases the resolution (spatial dimensions) of a tensor.
- Commonly used for upscaling images or feature maps, often within generative models (e.g., GANs) or when processing data at different scales.
Arguments
input (Tensor): The input tensor to be upsampled, shaped (batch, channels, followed by the spatial dimensions).
size (int or tuple of ints, optional): Desired output spatial size. A single integer applies to every spatial dimension; a tuple sets each dimension independently.
scale_factor (float or tuple of floats, optional): Multiplier for the spatial size. Values greater than 1 perform upsampling, while values less than 1 perform downsampling. Specify either size or scale_factor, not both.
mode (str, optional): The algorithm used for upsampling. Supported modes are:
- 'nearest' (default): Replicates nearest neighbor pixel values.
- 'linear' (1D only, i.e. 3D input tensors): Performs linear interpolation.
- 'bilinear' (2D only, i.e. 4D input tensors): Performs bilinear interpolation, commonly used for image upsampling.
- 'bicubic' (2D only, i.e. 4D input tensors): Performs bicubic interpolation, offering smoother results than bilinear but computationally more expensive.
- 'trilinear' (3D only, i.e. 5D input tensors): Performs trilinear interpolation, the 3D equivalent of bilinear interpolation.
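To make the size and scale_factor arguments concrete, here is a minimal sketch using torch.nn.functional.interpolate, which accepts the same arguments. The tensor shape and variable names are illustrative assumptions, not part of any API.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10)  # illustrative input: 1 image, 3 channels, 10x10

# scale_factor > 1 upsamples, scale_factor < 1 downsamples
up = F.interpolate(x, scale_factor=2, mode='nearest')
down = F.interpolate(x, scale_factor=0.5, mode='nearest')

# A tuple sets each spatial dimension independently (height, width)
stretched = F.interpolate(x, size=(15, 40), mode='nearest')

# Passing both size and scale_factor is rejected with a ValueError
try:
    F.interpolate(x, size=(20, 20), scale_factor=2, mode='nearest')
    both_rejected = False
except ValueError:
    both_rejected = True

print(up.shape, down.shape, stretched.shape, both_rejected)
```

Since only one of the two arguments may be given, size is the natural choice when the target resolution is fixed, and scale_factor when it should follow the input.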
Important Note
torch.nn.functional.upsample is deprecated in favor of the more general torch.nn.functional.interpolate function. Use interpolate in new code.
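As a rough migration check, the sketch below compares the deprecated call with its replacement on the same input. It assumes the deprecated alias may be missing from some PyTorch releases, so the call is guarded; the variable names are illustrative.

```python
import warnings

import torch
import torch.nn.functional as F

x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Recommended call
new = F.interpolate(x, scale_factor=2, mode='nearest')

# Assumption: the deprecated alias might be absent in some releases, so guard it
if hasattr(F, 'upsample'):
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')  # silence the deprecation warning
        old = F.upsample(x, scale_factor=2, mode='nearest')
    same = torch.equal(old, new)
else:
    same = None  # alias not available; nothing to compare

print(new.shape, same)
```

Where the alias is present, the two results should match, so migrating is typically just a rename of the function.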
Example (Bilinear Upsampling)
import torch
from torch import nn
input = torch.randn(1, 3, 10, 10) # Batch size 1, 3 channels, 10x10 image
scale_factor = 2 # Upsample by a factor of 2
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bilinear')
print(upsampled.shape) # Output: torch.Size([1, 3, 20, 20])
- Consider using torch.nn.functional.interpolate instead of upsample for future development.
- The choice of mode affects the quality and smoothness of the upsampled output. Bilinear interpolation is a good balance between quality and computational cost for most image upsampling tasks.
- Upsampling introduces new data points, so the resulting tensor will have increased spatial dimensions.
Nearest Neighbor Upsampling
import torch
from torch import nn
input = torch.randn(1, 3, 10, 10) # Batch size 1, 3 channels, 10x10 image
scale_factor = 2
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='nearest')
print(upsampled.shape) # Output: torch.Size([1, 3, 20, 20])
This code performs nearest neighbor upsampling, which is computationally efficient but can result in blocky artifacts.
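The blockiness is easy to see on a tiny tensor. In this minimal sketch (the 2x2 input is an illustrative assumption), each input value is simply repeated in a 2x2 block of the output:

```python
import torch
import torch.nn.functional as F

# Illustrative 2x2 single-channel "image"
x = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]]).reshape(1, 1, 2, 2)

up = F.interpolate(x, scale_factor=2, mode='nearest')
print(up[0, 0])
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```

No new values are computed, which is why this mode is cheap and why edges in the result look like enlarged pixels.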
Bicubic Upsampling (2D)
import torch
from torch import nn
input = torch.randn(1, 3, 10, 10) # Batch size 1, 3 channels, 10x10 image
scale_factor = 3
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bicubic')
print(upsampled.shape) # Output: torch.Size([1, 3, 30, 30])
This code uses bicubic interpolation for smoother upsampling, but it's more computationally expensive than bilinear.
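One practical caveat: the bicubic kernel has negative lobes, so output values can overshoot the input range near sharp edges. The sketch below assumes an image with values in [0, 1] and a hard vertical edge (both illustrative choices) and clamps the result back into range.

```python
import torch
import torch.nn.functional as F

# Illustrative image in [0, 1]: left half 0, right half 1
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0

up = F.interpolate(img, scale_factor=2, mode='bicubic', align_corners=False)

# Near the edge, values may fall slightly outside [0, 1]
print(up.min().item(), up.max().item())

# Clamp to restore the valid range before further processing
clamped = up.clamp(0.0, 1.0)
print(clamped.min().item(), clamped.max().item())
```

For 8-bit images the same idea applies with clamp(0, 255).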
Trilinear Upsampling (3D)
import torch
from torch import nn
input = torch.randn(1, 3, 5, 10, 10) # Batch size 1, 3 channels, 5x10x10 volume
scale_factor = 2
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='trilinear')
print(upsampled.shape) # Output: torch.Size([1, 3, 10, 20, 20])
This code performs trilinear interpolation for upsampling 3D volumes (e.g., video data).
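For video-like data it is often useful to scale only some axes. As a sketch, assuming the depth axis holds frames that should be left untouched, a tuple scale_factor can double height and width while keeping the frame count:

```python
import torch
import torch.nn.functional as F

# (batch, channels, frames, height, width); shape chosen for illustration
clip = torch.randn(1, 3, 5, 10, 10)

# One factor per spatial dimension: keep frames, double height and width
up = F.interpolate(clip, scale_factor=(1.0, 2.0, 2.0),
                   mode='trilinear', align_corners=False)
print(up.shape)  # torch.Size([1, 3, 5, 20, 20])
```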
Upsampling to a Specific Size
import torch
from torch import nn
input = torch.randn(1, 3, 10, 10) # Batch size 1, 3 channels, 10x10 image
output_size = (20, 20)
upsampled = nn.functional.interpolate(input, size=output_size, mode='bilinear')
print(upsampled.shape) # Output: torch.Size([1, 3, 20, 20])
This code upsamples the input to a specified output size (output_size).
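A fixed size is handy when the output must match another tensor exactly, for example before concatenating feature maps whose sizes are not integer multiples of each other. The names decoder_feat and skip_feat below are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn.functional as F

decoder_feat = torch.randn(1, 64, 12, 12)  # hypothetical coarse feature map
skip_feat = torch.randn(1, 64, 25, 25)     # hypothetical target; 25 is not a multiple of 12

# Take the target height and width directly from the other tensor
up = F.interpolate(decoder_feat, size=skip_feat.shape[-2:],
                   mode='bilinear', align_corners=False)

merged = torch.cat([up, skip_feat], dim=1)  # channel-wise concatenation
print(merged.shape)  # torch.Size([1, 128, 25, 25])
```

Using size here avoids the off-by-one mismatches that a scale_factor could produce with odd dimensions.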
torch.nn.functional.interpolate
This function offers a more general and powerful approach to upsampling (and downsampling) tensors. It provides the following advantages:
- Additional Features: Supports interpolation for different dimensions (1D, 2D, 3D), and includes an alignment option (align_corners) for the interpolating modes ('linear', 'bilinear', 'bicubic', 'trilinear').
- Flexibility: Supports various upsampling modes ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear') and allows specifying the target output size or scale factor.
import torch
from torch import nn
input = torch.randn(1, 3, 10, 10) # Batch size 1, 3 channels, 10x10 image
scale_factor = 2
# Equivalent to deprecated upsample with bilinear mode
upsampled = nn.functional.interpolate(input, scale_factor=scale_factor, mode='bilinear')
print(upsampled.shape) # Output: torch.Size([1, 3, 20, 20])
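The effect of align_corners is easiest to see in 1D. In this minimal sketch (the two-element input is an illustrative assumption), align_corners=True treats the first and last samples as the endpoints of the output grid, while align_corners=False treats each sample as the center of a pixel:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[0.0, 1.0]]])  # shape (batch=1, channels=1, length=2)

# Endpoints of input and output coincide: evenly spaced from 0 to 1
aligned = F.interpolate(x, size=4, mode='linear', align_corners=True)

# Pixel-center sampling: positions outside the input are clamped to the border
unaligned = F.interpolate(x, size=4, mode='linear', align_corners=False)

print(aligned)    # values 0, 1/3, 2/3, 1
print(unaligned)  # values 0, 0.25, 0.75, 1
```

Both calls return the same shape; only the sampling positions differ, so the setting matters mainly when outputs must line up with results from other tools or earlier code.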