Understanding MultiplicativeLR for Learning Rate Optimization in PyTorch


Purpose

  • This scheduler dynamically adjusts the learning rate of each parameter group in an optimizer throughout training by multiplying it by a factor you supply.
  • It allows you to implement custom learning rate decay or growth strategies.

How it Works

  1. Initialization

    • You create a MultiplicativeLR instance, passing:
      • optimizer: The PyTorch optimizer you're using (e.g., Adam, SGD).
      • lr_lambda: A function (or list of functions) that calculates the multiplicative factor for the learning rate at each epoch.
        • The function takes an integer epoch as input and returns a float representing the factor.
        • Alternatively, you can provide a list of functions, one for each parameter group in the optimizer.
      • last_epoch (optional): The index of the last completed epoch (usually set to -1 initially).
      • verbose (optional): A boolean flag for printing learning rates (deprecated).

  2. Step Function

    • scheduler.step() is called after each training epoch to update the learning rates:
      • The scheduler iterates through the optimizer's parameter groups.
      • For each group, it:
        • Retrieves the current learning rate (lr).
        • Calls the corresponding lr_lambda function with the current epoch to get the multiplicative factor (factor).
        • Updates the learning rate by multiplying the current lr by the factor.

Code Example

import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR

def lr_lambda(epoch):
    return 0.95  # Decay learning rate by 5% every epoch

model = ...  # Your PyTorch model
optimizer = optim.Adam(model.parameters())
scheduler = MultiplicativeLR(optimizer, lr_lambda=lr_lambda)

# Train your model...

for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch
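
The factor compounds: after n calls to step(), the learning rate equals the initial value multiplied by the product of all returned factors (0.95**n here). A minimal self-contained check, using a dummy parameter in place of a real model and an illustrative initial lr of 0.1, makes this visible:

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR

param = torch.nn.Parameter(torch.zeros(1))  # Dummy parameter standing in for a model
optimizer = optim.Adam([param], lr=0.1)
scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

for epoch in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # ~[0.095], [0.09025], [0.0857], i.e. 0.1 * 0.95**(epoch + 1)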

Customization and Considerations

  • Experiment with different factors and learning rate schedules to find the optimal configuration for your problem.
  • Be mindful of over-decaying the learning rate, which can hinder convergence.
  • Consider using techniques like warmup or cyclical learning rates in conjunction with MultiplicativeLR.
  • You can define different learning rate decay/growth functions in lr_lambda to suit your needs, as sketched below.
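
For example, a minimal sketch of such custom functions (the function names and the 30-epoch cutoff are arbitrary illustrations): one decays for a fixed number of epochs and then holds the learning rate constant to avoid over-decaying, while a factor above 1.0 grows the learning rate instead:

def decay_then_hold(epoch):
    # Multiply the previous lr by 0.95 for the first 30 epochs, then return 1.0
    # so the lr stays constant and does not decay further.
    return 0.95 if epoch <= 30 else 1.0

def slow_growth(epoch):
    # A factor greater than 1.0 grows the learning rate by 1% each epoch.
    return 1.01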


Example 1: Exponential Decay with Different Rates for Different Parameter Groups

This example shows how to define separate learning rate decay functions for distinct parameter groups in the optimizer:

import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR

def lr_lambda1(epoch):
    return 0.98  # Decay by 2% for the first parameter group

def lr_lambda2(epoch):
    return 0.95  # Decay by 5% for the second parameter group

model = ...  # Your PyTorch model
optimizer = optim.Adam([{'params': model.fc1.parameters()}, {'params': model.fc2.parameters()}])
scheduler = MultiplicativeLR(optimizer, lr_lambda=[lr_lambda1, lr_lambda2])

# Train your model...

for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch

Example 2: Linear Warmup

This example implements a linear warmup followed by decay. Because MultiplicativeLR multiplies the previous learning rate by the returned factor, the warmup factors must be expressed as ratios between successive target learning rates: the optimizer starts at peak_lr / warmup_epochs, and the learning rate ramps linearly up to peak_lr over the warmup epochs before decaying:

import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR

warmup_epochs = 5
peak_lr = 0.01

def lr_lambda(epoch, warmup_epochs):
    if epoch == 0:
        return 1.0  # MultiplicativeLR leaves the initial lr unchanged at epoch 0
    elif epoch < warmup_epochs:
        return (epoch + 1) / epoch  # Ratio of successive warmup lrs: lr ramps linearly to peak_lr
    else:
        return 0.95  # Decay by 5% per epoch after warmup

model = ...  # Your PyTorch model
optimizer = optim.SGD(model.parameters(), lr=peak_lr / warmup_epochs)  # Start at the first warmup lr
scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda epoch: lr_lambda(epoch, warmup_epochs))

# Train your model...

for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch


StepLR

  • Decreases the learning rate by a factor (gamma) every specified number of epochs (step_size).
  • Useful for simple learning rate reductions at predefined intervals.
  • Example:
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Multiply lr by 0.1 (i.e. cut it to 10%) every 10 epochs

MultiStepLR

  • Decays the learning rate by a factor (gamma) at specific epochs defined in a list (milestones).
  • Offers more control over when to decrease the learning rate.
  • Example:
from torch.optim.lr_scheduler import MultiStepLR

scheduler = MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)  # Multiply lr by 0.1 at epochs 20 and 40

ReduceLROnPlateau

  • Monitors a validation metric (e.g., validation loss) and reduces the learning rate if it plateaus for a specified number of epochs (patience).
  • Adapts the learning rate based on validation performance.
  • Example:
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)  # Multiply lr by 0.1 if the monitored metric doesn't improve for 5 epochs
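
Unlike the epoch-based schedulers, ReduceLROnPlateau needs the monitored metric passed to step(). A minimal sketch of the loop (validate is a placeholder for your own validation routine):

for epoch in range(num_epochs):
    # ... training loop ...
    val_loss = validate(model)  # Placeholder: compute your validation metric
    scheduler.step(val_loss)    # Pass the monitored metric to step()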

CosineAnnealingLR

  • Gradually reduces the learning rate following a cosine curve from the initial value to a minimum (eta_min) over a specified number of epochs (T_max).
  • Often helpful for fine-tuning after initial learning rate decay.
  • Example:
from torch.optim.lr_scheduler import CosineAnnealingLR

scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.001)  # Cosine annealing over T_max=10 epochs down to a minimum lr of 0.001

CosineAnnealingWarmRestarts

  • Similar to CosineAnnealingLR, but restarts the cosine annealing cycle: the first cycle lasts T_0 epochs, and each subsequent cycle can be lengthened by a factor T_mult.
  • Can be useful for training models for very long durations.
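  • Example (a minimal sketch; the T_0, T_mult, and eta_min values are illustrative):
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.001)  # First cycle lasts 10 epochs; each restart doubles the cycle length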

OneCycleLR

  • Implements a one-cycle policy: the learning rate rises to a peak (max_lr) and then decays to a minimum, with step() typically called after every batch rather than every epoch.
  • Can be effective for models that benefit from aggressive learning rate changes.
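  • Example (a minimal sketch; max_lr is illustrative, and num_epochs and train_loader are placeholders):
from torch.optim.lr_scheduler import OneCycleLR

scheduler = OneCycleLR(optimizer, max_lr=0.01, epochs=num_epochs, steps_per_epoch=len(train_loader))

for epoch in range(num_epochs):
    for batch in train_loader:
        # ... forward / backward / optimizer.step() ...
        scheduler.step()  # OneCycleLR is stepped after every batch, not every epoch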

Choosing the Right Scheduler

The best scheduler depends on your specific problem and training dynamics. Here are some additional considerations:

  • Experimentation
    Try different schedulers to find the one that works best for your task.
  • Model Architecture
    Some models (e.g., CNNs) may be more sensitive to initial learning rate selection.
  • Problem Type
    For simpler problems, StepLR or MultiStepLR might suffice. More complex problems might benefit from schedulers that adapt to validation performance or follow annealing strategies.