Understanding MultiplicativeLR for Learning Rate Optimization in PyTorch
Purpose
- This scheduler dynamically adjusts the learning rate of each parameter group in an optimizer throughout training by multiplying it by a user-defined factor.
- It allows you to implement custom learning rate decay or growth strategies.
How it Works
- You create a MultiplicativeLR instance, passing:
  - optimizer: The PyTorch optimizer you're using (e.g., Adam, SGD).
  - lr_lambda: A function (or list of functions) that calculates the multiplicative factor for the learning rate at each epoch.
    - The function takes an integer epoch as input and returns a float representing the factor.
    - Alternatively, you can provide a list of functions, one for each parameter group in the optimizer.
  - last_epoch (optional): The index of the last completed epoch (usually left at the default of -1).
  - verbose (optional): A boolean flag for printing learning rates (deprecated).
Step Function
- scheduler.step() is called after each training epoch to update the learning rates; a simplified sketch follows this list.
  - The scheduler iterates through the optimizer's parameter groups.
  - For each group, it:
    - Retrieves the current learning rate (lr).
    - Calls the corresponding lr_lambda function with the current epoch to get the multiplicative factor (factor).
    - Updates the learning rate by multiplying the current lr by factor.
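Conceptually, each step() call performs an update equivalent to the simplified sketch below (names such as lr_lambdas and epoch stand in for the scheduler's internal state; the actual logic lives in torch.optim.lr_scheduler):
# Simplified view of what scheduler.step() does for MultiplicativeLR
epoch += 1  # the scheduler tracks this internally as last_epoch
for lmbda, group in zip(lr_lambdas, optimizer.param_groups):
    factor = lmbda(epoch)               # multiplicative factor for this epoch
    group['lr'] = group['lr'] * factor  # compound the factor onto the current learning rate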
Code Example
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR
def lr_lambda(epoch):
    return 0.95  # Decay learning rate by 5% every epoch
model = ...  # Your PyTorch model
optimizer = optim.Adam(model.parameters())
scheduler = MultiplicativeLR(optimizer, lr_lambda=lr_lambda)
# Train your model...
for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch
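To check that the schedule behaves as expected, you can read the current learning rates after each step; get_last_lr() returns one value per parameter group:
print(scheduler.get_last_lr())          # e.g. [0.00095] after one step, starting from Adam's default lr of 0.001
print(optimizer.param_groups[0]['lr'])  # the same value, read directly from the optimizer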
Customization and Considerations
- Experiment with different factors and learning rate schedules to find the optimal configuration for your problem.
- Be mindful of over-decaying the learning rate, which can hinder convergence.
- Consider using techniques like warmup or cyclical learning rates in conjunction with MultiplicativeLR.
- You can define different learning rate decay/growth functions in lr_lambda to suit your needs.
Example 1: Exponential Decay with Different Rates for Different Parameter Groups
This example shows how to define separate learning rate decay functions for distinct parameter groups in the optimizer:
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR
def lr_lambda1(epoch):
    return 0.98  # Decay by 2% per epoch for the first parameter group
def lr_lambda2(epoch):
    return 0.95  # Decay by 5% per epoch for the second parameter group
model = ...  # Your PyTorch model
optimizer = optim.Adam([{'params': model.fc1.parameters()}, {'params': model.fc2.parameters()}])
scheduler = MultiplicativeLR(optimizer, lr_lambda=[lr_lambda1, lr_lambda2])
# Train your model...
for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch
Example 2: Linear Warmup
This example implements a linear warmup in which the learning rate ramps from a fraction of the target rate up to the full target over a specified number of warmup epochs, then decays. Because MultiplicativeLR compounds each factor onto the current learning rate (rather than scaling the base rate directly), the lambda returns the ratio between consecutive warmup targets, and the optimizer is initialized at the first warmup value:
import torch.optim as optim
from torch.optim.lr_scheduler import MultiplicativeLR
def lr_lambda(epoch, warmup_epochs):
    if epoch < warmup_epochs:
        # Ratio of consecutive warmup targets; the compounded product gives
        # lr = target_lr * (epoch + 1) / warmup_epochs. The scheduler only calls
        # this with epoch >= 1, so there is no division by zero.
        return (epoch + 1) / epoch
    else:
        return 0.95  # Decay by 5% per epoch after warmup
warmup_epochs = 5
target_lr = 0.01
model = ...  # Your PyTorch model
optimizer = optim.SGD(model.parameters(), lr=target_lr / warmup_epochs)  # start at the first warmup value
scheduler = MultiplicativeLR(optimizer, lr_lambda=lambda epoch: lr_lambda(epoch, warmup_epochs))
# Train your model...
for epoch in range(num_epochs):
    # ... training loop ...
    scheduler.step()  # Update learning rates after each epoch
StepLR
- Decreases the learning rate by a factor (gamma) every specified number of epochs (step_size).
- Useful for simple learning rate reductions at predefined intervals.
- Example:
from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # Multiply lr by 0.1 every 10 epochs
MultiStepLR
- Decays the learning rate by a factor (gamma) at specific epochs defined in a list (milestones).
- Offers more control over when to decrease the learning rate.
- Example:
from torch.optim.lr_scheduler import MultiStepLR
scheduler = MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)  # Multiply lr by 0.1 at epochs 20 and 40
ReduceLROnPlateau
- Monitors a validation metric (e.g., validation loss) and reduces the learning rate if it plateaus for a specified number of epochs (patience).
- Adapts the learning rate based on validation performance.
- Example:
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)  # Multiply lr by 0.1 if the monitored metric doesn't improve for 5 epochs
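Unlike the other schedulers, ReduceLROnPlateau must be given the monitored metric when stepping. A minimal usage sketch, assuming a validation loss is computed at the end of each epoch:
for epoch in range(num_epochs):
    # ... training loop ...
    val_loss = ...  # compute the validation loss for this epoch
    scheduler.step(val_loss)  # pass the monitored metric; lr is reduced if it stops improving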
CosineAnnealingLR
- Gradually reduces the learning rate following a cosine curve, from the initial value down to a minimum (eta_min), over a specified number of epochs (T_max).
- Often helpful for fine-tuning after an initial learning rate decay.
- Example:
from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.001)  # Cosine annealing over T_max=10 epochs down to a minimum lr of 0.001
CosineAnnealingWarmRestarts
- Similar to CosineAnnealingLR, but restarts the cosine annealing cycle periodically; the cycle length can grow by a factor (T_mult) after each restart.
- Can be useful for training models for very long durations.
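A minimal sketch, reusing the optimizer from the earlier examples (the specific values are illustrative); T_0 sets the length of the first cycle and T_mult stretches each subsequent cycle:
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.001)  # cycles of 10, 20, 40, ... epochs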
OneCycleLR
- Implements a one-cycle policy: the learning rate rises to a peak (max_lr) and then decays to a minimum; it is typically stepped once per batch rather than once per epoch.
- Can be effective for models that benefit from aggressive learning rate changes.
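A minimal sketch, assuming a DataLoader named train_loader and the num_epochs variable from the earlier examples; note that the scheduler is stepped after every batch:
from torch.optim.lr_scheduler import OneCycleLR
scheduler = OneCycleLR(optimizer, max_lr=0.01, epochs=num_epochs, steps_per_epoch=len(train_loader))
for epoch in range(num_epochs):
    for batch in train_loader:
        # ... forward pass, backward pass, optimizer.step() ...
        scheduler.step()  # update the learning rate after every batch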
Choosing the Right Scheduler
The best scheduler depends on your specific problem and training dynamics. Here are some additional considerations:
- Training Experience: Experiment with different schedulers to find the one that works best for your task.
- Model Architecture: Some models (e.g., CNNs) may be more sensitive to initial learning rate selection and scheduling.
- Problem Type: For simpler problems, StepLR or MultiStepLR might suffice. More complex problems might benefit from schedulers that adapt to validation performance or follow annealing strategies.